/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // Created by raver on 8/4/2018. // #include "testlayers.h" #include #include #include #include #include #include using namespace nd4j; class DeclarableOpsTests12 : public testing::Test { public: DeclarableOpsTests12() { printf("\n"); fflush(stdout); } }; TEST_F(DeclarableOpsTests12, test_any_validation_1) { auto x = NDArrayFactory::create('c', {2, 1}, {1.0, 2.0}); auto y = NDArrayFactory::create('c', {2}, {1, 0}); nd4j::ops::transpose op; auto result = op.execute({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); ASSERT_EQ(x.dataType(), z->dataType()); delete result; } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test1) { NDArray labels('c', {2,4}, {0,1,1,0,1,0,1,0}); NDArray predictions('c', {2,4}, nd4j::DataType::DOUBLE); NDArray weights('c', {2,1}, nd4j::DataType::DOUBLE); NDArray dLdpExp('c', {2,4}, {-0. , -0.5, -0.5, -0., -0.5, -0. , -0.5, -0.}); NDArray dLdwExp('c', {2,1}, {1.2, -0.2}); predictions.linspace(-0.4, 0.2); weights.assign(0.5); nd4j::ops::cosine_distance_loss_grad op; auto results = op.execute({&predictions, &weights, &labels}, {}, {0, -1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto *dLdp = results->at(0); auto *dLdw = results->at(1); auto *dLdl = results->at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); delete results; } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test2) { NDArray labels('c', {2,4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2}); NDArray predictions('c', {2,4}, nd4j::DataType::DOUBLE); NDArray weights('c', {1,4}, nd4j::DataType::DOUBLE); NDArray dLdpExp('c', {2,4}, {0.05, -0.15, -1. , 0.7 ,-1.25, 1.5 , -0.6 , -1.1 }); NDArray dLdwExp('c', {1,4}, {-0.04, 2.86, 0.04, -0.92}); NDArray dLdlExp('c', {2,4}, {0.2, 0.1, 0. , -0.1, -0.2, -0.3, -0.4, -0.5}); predictions.linspace(-0.4, 0.2); weights.assign(0.5); nd4j::ops::cosine_distance_loss_grad op; auto results = op.execute({&predictions, &weights, &labels}, {}, {0, 0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto *dLdp = results->at(0); auto *dLdw = results->at(1); auto *dLdl = results->at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); delete results; } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test3) { NDArray labels('c', {4}, {-0.1, 0.3, 2, -1.4}); NDArray predictions('c', {4}, nd4j::DataType::DOUBLE); NDArray weights('c', {1}, nd4j::DataType::DOUBLE); NDArray dLdpExp('c', {4}, {0.05, -0.15, -1., 0.7}); NDArray dLdwExp('c', {1}, {1.3}); NDArray dLdlExp('c', {4}, {0.2, 0.1, -0. , -0.1}); predictions.linspace(-0.4, 0.2); weights.assign(0.5); nd4j::ops::cosine_distance_loss_grad op; auto results = op.execute({&predictions, &weights, &labels}, {}, {0, 0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto *dLdp = results->at(0); auto *dLdw = results->at(1); auto *dLdl = results->at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); delete results; } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test4) { NDArray labels('c', {1,4}, {-0.1, 0.3, 2, -1.4}); NDArray predictions('c', {1,4}, nd4j::DataType::DOUBLE); NDArray weights('c', {}, {0.}, nd4j::DataType::DOUBLE); NDArray dLdpExp('c', {1,4}, {0.05, -0.15, -1., 0.7}); NDArray dLdwExp('c', {}, {1.3}); NDArray dLdlExp('c', {1,4}, {0.2, 0.1, -0. , -0.1}); predictions.linspace(-0.4, 0.2); weights.assign(0.5); nd4j::ops::cosine_distance_loss_grad op; auto results = op.execute({&predictions, &weights, &labels}, {}, {1, 1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto *dLdp = results->at(0); auto *dLdw = results->at(1); auto *dLdl = results->at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); delete results; } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test5) { NDArray labels('c', {4}, {-0.1, 0.3, 2, -1.4}, nd4j::DataType::DOUBLE); NDArray predictions('c', {4}, nd4j::DataType::DOUBLE); NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE); NDArray dLdpExp('c', {4}, {0.1, -0.3, -2. , 1.4}); NDArray dLdwExp('c', {1,1}, {0.}); NDArray dLdlExp('c', {4}, {0.4, 0.2, -0. , -0.2}); predictions.linspace(-0.4, 0.2); weights = 0.5; nd4j::ops::cosine_distance_loss_grad op; auto results = op.execute({&predictions, &weights, &labels}, {}, {2, 0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto *dLdp = results->at(0); auto *dLdw = results->at(1); auto *dLdl = results->at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); delete results; } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test6) { NDArray labels('c', {4,1}, {-0.1, 0.3, 2, -1.4}, nd4j::DataType::DOUBLE); NDArray predictions('c', {4,1}, nd4j::DataType::DOUBLE); NDArray weights('c', {4,1}, nd4j::DataType::DOUBLE); NDArray dLdpExp('c', {4,1}, {0.0125, -0.0375, -0.25 , 0.175}); NDArray dLdwExp('c', {4,1}, {0.24 , 0.265, 0.25 , 0.32}); NDArray dLdlExp('c', {4,1}, {0.05 , 0.025, -0. , -0.025}); predictions.linspace(-0.4, 0.2); weights = 0.5; nd4j::ops::cosine_distance_loss_grad op; auto results = op.execute({&predictions, &weights, &labels}, {}, {3, 1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto *dLdp = results->at(0); auto *dLdw = results->at(1); auto *dLdl = results->at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); delete results; } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test7) { NDArray labels('c', {2,3,4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2,-0.1, 0.3, 2, -3.4, 2.5, -3, 1.2, 2.2,-0.2, 0.3, 2, -1.4, 2.7, -3, 1.2, 4.2}); NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE); NDArray dLdpExp('c', {2,3,4}, {0.00833, -0.025 , -0.16667, 0.11667,-0.20833, 0.25 , -0.1 , -0.18333, 0.00833, -0.025 , -0.16667, 0.28333, -0.20833, 0.25 , -0.1 , -0.18333, 0.01667, -0.025 , -0.16667, 0.11667,-0.225 , 0.25 , -0.1 , -0.35 }); NDArray dLdwExp('c', {1,3,1}, {0.50444, 0.89778, -1.40222}); NDArray dLdlExp('c', {2,3,4}, {0.03333, 0.01667, -0. , -0.01667,-0.03333, -0.05 , -0.06667, -0.08333,-0.1, -0.11667, -0.13333, -0.15, -0.16667, -0.18333, -0.2 , -0.21667,-0.23333, -0.25 , -0.26667, -0.28333,-0.3, -0.31667, -0.33333, -0.35 }); predictions.linspace(-0.4, 0.2); weights = 0.5; nd4j::ops::cosine_distance_loss_grad op; auto results = op.execute({&predictions, &weights, &labels}, {}, {2, 0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto *dLdp = results->at(0); auto *dLdw = results->at(1); auto *dLdl = results->at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); delete results; } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test8) { NDArray labels('c', {2,3,4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2,-0.1, 0.3, 2, -3.4, 2.5, -3, 1.2, 2.2,-0.2, 0.3, 2, -1.4, 2.7, -3, 1.2, 4.2}); NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray weights('c', {2,1,1}, nd4j::DataType::DOUBLE); NDArray dLdpExp('c', {2,3,4}, {0.00625, -0.01875, -0.125 , 0.0875,-0.15625, 0.1875 , -0.075 , -0.1375, 0.00625, -0.01875, -0.125 , 0.2125, -0.15625, 0.1875 , -0.075 , -0.1375, 0.0125 , -0.01875, -0.125 , 0.0875,-0.16875, 0.1875 , -0.075 , -0.2625}); NDArray dLdwExp('c', {2,1,1}, {0.57, -3.2175}); NDArray dLdlExp('c', {2,3,4}, {0.025, 0.0125, -0. , -0.0125,-0.025, -0.0375, -0.05, -0.0625,-0.075, -0.0875, -0.1 , -0.1125, -0.125, -0.1375, -0.15, -0.1625,-0.175, -0.1875, -0.2 , -0.2125,-0.225, -0.2375, -0.25, -0.2625}); predictions.linspace(-0.4, 0.2); weights = 0.5; nd4j::ops::cosine_distance_loss_grad op; auto results = op.execute({&predictions, &weights, &labels}, {}, {3, 1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto *dLdp = results->at(0); auto *dLdw = results->at(1); auto *dLdl = results->at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); delete results; } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test9) { NDArray labels('c', {2,3,4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2,-0.1, 0.3, 2, -3.4, 2.5, -3, 1.2, 2.2,-0.2, 0.3, 2, -1.4, 2.7, -3, 1.2, 4.2}); NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray weights('c', {2,3,1}, nd4j::DataType::DOUBLE); NDArray dLdpExp('c', {2,3,4}, {0.05, -0.15, -1. , 0.7,-1.25, 1.5 , -0.6 , -1.1, 0.05, -0.15, -1. , 1.7, -1.25, 1.5 , -0.6 , -1.1, 0.1 , -0.15, -1. , 0.7,-1.35, 1.5 , -0.6 , -2.1}); NDArray dLdwExp('c', {2,3,1}, {1.3 , -1.36, 3.62, -6. , -0.98,-19.76}); NDArray dLdlExp('c', {2,3,4}, {0.2, 0.1, -0. , -0.1,-0.2, -0.3, -0.4, -0.5,-0.6, -0.7, -0.8, -0.9, -1. , -1.1, -1.2, -1.3,-1.4, -1.5, -1.6, -1.7,-1.8, -1.9, -2. , -2.1}); predictions.linspace(-0.4, 0.2); weights = 0.5; nd4j::ops::cosine_distance_loss_grad op; auto results = op.execute({&predictions, &weights, &labels}, {}, {0, 2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto *dLdp = results->at(0); auto *dLdw = results->at(1); auto *dLdl = results->at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); delete results; } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, hinge_loss_14) { NDArray logits('c', {3,4}, nd4j::DataType::DOUBLE); NDArray weights('c', {}, {1.}); NDArray labels('c', {3,4}, {0,1,1,0,1,0,1,0,1,0,1,0}); NDArray output('c', {}, {0.}, nd4j::DataType::DOUBLE); logits.linspace(1.); weights.assign(1.); nd4j::ops::hinge_loss op; Nd4jStatus status = op.execute({&logits, &weights, &labels}, {&output}, {}, {1}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_TRUE(output.e(0) == 47.); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestDivideBP_1) { NDArray x('c', {3,4}, nd4j::DataType::DOUBLE); NDArray y = NDArrayFactory::create(2.); NDArray eps('c', {3,4}, nd4j::DataType::DOUBLE); NDArray output1('c', {3, 4}, nd4j::DataType::DOUBLE); NDArray output2(nd4j::DataType::DOUBLE); x.linspace(2., 2.); eps.linspace(1.); nd4j::ops::divide_bp op; Nd4jStatus status = op.execute({&x, &y, &eps}, {&output1, &output2}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); output1.printIndexedBuffer("DivideBP X out"); output2.printIndexedBuffer("DivideBP Y out"); //ASSERT_TRUE(output.e(0) == 47.); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestDivideBP_2) { NDArray x('c', {3,4}, nd4j::DataType::DOUBLE); NDArray y = NDArrayFactory::create('c', {3,4}); NDArray eps('c', {3,4}, nd4j::DataType::DOUBLE); NDArray exp1('c', {3,4}, nd4j::DataType::DOUBLE); NDArray exp2('c', {3,4}, nd4j::DataType::DOUBLE); NDArray output1('c', {3, 4}, nd4j::DataType::DOUBLE); NDArray output2('c', {3, 4}, nd4j::DataType::DOUBLE); exp1.assign(1.); exp2.assign(-2.); x.linspace(2., 2.); y.linspace(1.); eps.linspace(1.); nd4j::ops::divide_bp op; Nd4jStatus status = op.execute({&x, &y, &eps}, {&output1, &output2}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); output1.printIndexedBuffer("2DivideBP X out"); output2.printIndexedBuffer("2DivideBP Y out"); ASSERT_TRUE(output1.equalsTo(exp1)); ASSERT_TRUE(output2.equalsTo(exp2)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestReverseDivideBP_1) { NDArray x('c', {3,4}, nd4j::DataType::DOUBLE); NDArray y = NDArrayFactory::create(2.); NDArray eps('c', {3,4}, nd4j::DataType::DOUBLE); NDArray output1('c', {3, 4}, nd4j::DataType::DOUBLE); NDArray output2(nd4j::DataType::DOUBLE); x.linspace(2., 2.); eps.linspace(1.); nd4j::ops::reversedivide_bp op; Nd4jStatus status = op.execute({&y, &x, &eps}, {&output2, &output1}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); output1.printIndexedBuffer("RDivideBP X out"); output2.printIndexedBuffer("RDivideBP Y out"); //ASSERT_TRUE(output.e(0) == 47.); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestReverseDivideBP_2) { NDArray x('c', {3,4}, nd4j::DataType::DOUBLE); NDArray y = NDArrayFactory::create('c', {3,4}); NDArray eps('c', {3,4}, nd4j::DataType::DOUBLE); NDArray exp1('c', {3,4}, nd4j::DataType::DOUBLE); NDArray exp2('c', {3,4}, nd4j::DataType::DOUBLE); NDArray output1('c', {3, 4}, nd4j::DataType::DOUBLE); NDArray output2('c', {3, 4}, nd4j::DataType::DOUBLE); x.linspace(2., 2.); y.linspace(1.); eps.linspace(1.); exp1.assign(1.); exp2.assign(-2.); nd4j::ops::reversedivide_bp op; Nd4jStatus status = op.execute({&y, &x, &eps}, {&output2, &output1}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); output1.printIndexedBuffer("2RDivideBP X out"); output2.printIndexedBuffer("2RDivideBP Y out"); ASSERT_TRUE(output1.equalsTo(exp1)); ASSERT_TRUE(output2.equalsTo(exp2)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestSliceBP_1) { NDArray x('c', {3,4}, nd4j::DataType::DOUBLE); NDArray eps('c', {2,2}, nd4j::DataType::DOUBLE); NDArray exp('c', {3,4}, {0., 0., 0., 0., 0., 1.,1., 0., 0., 1., 1., 0.}); //NDArray exp2('c', {3,4}, nd4j::DataType::DOUBLE); NDArray output('c', {3, 4}, nd4j::DataType::DOUBLE); //NDArray output2('c', {3, 4}, nd4j::DataType::DOUBLE); output.assign(119.113); x.linspace(1.); eps.assign(1.); //exp1.assign(1.); //exp2.assign(-2.); nd4j::ops::slice_bp op; Nd4jStatus status = op.execute({&x, &eps}, {&output}, {}, {1,1,2,2}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); output.printIndexedBuffer("SLICE_BP out"); ASSERT_TRUE(output.equalsTo(exp)); //ASSERT_TRUE(output2.equalsTo(exp2)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestConfusionZero_1) { NDArray x('c', {2}, {1,2}, nd4j::DataType::INT64); NDArray i('c', {2}, {0,2}, nd4j::DataType::INT64); //NDArray eps('c', {2,2}, nd4j::DataType::DOUBLE); NDArray exp('c', {4,4}, {0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0}, nd4j::DataType::INT64); //NDArray exp2('c', {3,4}, nd4j::DataType::DOUBLE); NDArray output('c', {4, 4}, nd4j::DataType::INT64); //NDArray output2('c', {3, 4}, nd4j::DataType::DOUBLE); output.assign(119.113); x.linspace(1.); //eps.assign(1.); //exp1.assign(1.); //exp2.assign(-2.); nd4j::ops::confusion_matrix op; Nd4jStatus status = op.execute({&x, &i}, {&output}, {}, {4}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); output.printIndexedBuffer("Confusion out"); ASSERT_TRUE(output.equalsTo(exp)); //ASSERT_TRUE(output2.equalsTo(exp2)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestMaximumBP_1) { NDArray x('c', {3,4}, nd4j::DataType::DOUBLE); NDArray y('c', {3,4}, nd4j::DataType::DOUBLE); NDArray eps('c', {3,4}, nd4j::DataType::DOUBLE); NDArray exp1('c', {3,4}, {0, 0, 0, 0, 0, 0, 7, 8, 9, 10, 11, 12}, nd4j::DataType::DOUBLE); NDArray exp2('c', {3,4}, {1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0}, nd4j::DataType::DOUBLE); NDArray output1('c', {3, 4}, nd4j::DataType::DOUBLE); NDArray output2('c', {3, 4}, nd4j::DataType::DOUBLE); output1.assign(119); x.linspace(1.); y.linspace(12., -1.); x.printBuffer("X"); y.printBuffer("Y"); eps.linspace(1.); //exp1.assign(1.); //exp2.assign(-2.); nd4j::ops::maximum_bp op; Nd4jStatus status = op.execute({&x, &y, &eps}, {&output1, &output2}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); output1.printIndexedBuffer("X max"); output2.printIndexedBuffer("Y max"); ASSERT_TRUE(output1.equalsTo(exp1)); ASSERT_TRUE(output2.equalsTo(exp2)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestMinimumBP_1) { NDArray x('c', {3,4}, nd4j::DataType::DOUBLE); NDArray y('c', {3,4}, nd4j::DataType::DOUBLE); NDArray eps('c', {3,4}, nd4j::DataType::DOUBLE); NDArray exp1('c', {3,4}, {0, 0, 0, 0, 0, 0, 7, 8, 9, 10, 11, 12}, nd4j::DataType::DOUBLE); NDArray exp2('c', {3,4}, {1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0}, nd4j::DataType::DOUBLE); NDArray output1('c', {3, 4}, nd4j::DataType::DOUBLE); NDArray output2('c', {3, 4}, nd4j::DataType::DOUBLE); output1.assign(119); x.linspace(1.); y.linspace(12., -1.); x.printBuffer("X"); y.printBuffer("Y"); eps.linspace(1.); //exp1.assign(1.); //exp2.assign(-2.); nd4j::ops::minimum_bp op; Nd4jStatus status = op.execute({&x, &y, &eps}, {&output2, &output1}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); output2.printIndexedBuffer("X min"); output1.printIndexedBuffer("Y min"); ASSERT_TRUE(output1.equalsTo(exp1)); ASSERT_TRUE(output2.equalsTo(exp2)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, reverse_test15) { NDArray x('c', {5}, {1,2,3,4,5}, nd4j::DataType::DOUBLE); NDArray axis('c', {}, {0}, nd4j::DataType::INT32); NDArray z('c', {5}, nd4j::DataType::DOUBLE); NDArray exp('c', {5}, {5,4,3,2,1}, nd4j::DataType::DOUBLE); nd4j::ops::reverse op; // auto result = op.execute({&x, &axis}, {}, {1}, {}); Nd4jStatus status = op.execute({&x, &axis}, {&z}, {}, {1}, {}); // auto z = result->at(0); // z->printIndexedBuffer(); ASSERT_EQ(Status::OK(), status); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); // delete result; } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, mirrorPad_test17) { NDArray x('c', {2,3}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE); NDArray padding('c', {2,2}, {1,1,2,2}, nd4j::DataType::INT32); NDArray z('c', {4,7}, nd4j::DataType::DOUBLE); NDArray exp1('c', {4,7}, {6, 5, 4, 5, 6, 5, 4,3, 2, 1, 2, 3, 2, 1,6, 5, 4, 5, 6, 5, 4,3, 2, 1, 2, 3, 2, 1}, nd4j::DataType::DOUBLE); NDArray exp2('c', {4,7}, {2, 1, 1, 2, 3, 3, 2,2, 1, 1, 2, 3, 3, 2,5, 4, 4, 5, 6, 6, 5,5, 4, 4, 5, 6, 6, 5}, nd4j::DataType::DOUBLE); nd4j::ops::mirror_pad op; Nd4jStatus status = op.execute({&x, &padding}, {&z}, {}, {0}, {}); // reflect ASSERT_EQ(Status::OK(), status); ASSERT_TRUE(exp1.isSameShape(z)); ASSERT_TRUE(exp1.equalsTo(z)); z = 0.; status = op.execute({&x, &padding}, {&z}, {}, {1}, {}); // symmetric ASSERT_EQ(Status::OK(), status); ASSERT_TRUE(exp2.isSameShape(z)); ASSERT_TRUE(exp2.equalsTo(z)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, mirrorPad_test18) { NDArray x('c', {3}, {1,2,3}, nd4j::DataType::DOUBLE); NDArray padding('c', {2}, {1,1}, nd4j::DataType::INT32); NDArray z('c', {5}, nd4j::DataType::DOUBLE); NDArray exp('c', {5}, {2,1,2,3,2}, nd4j::DataType::DOUBLE); nd4j::ops::mirror_pad op; Nd4jStatus status = op.execute({&x, &padding}, {&z}, {}, {0}, {}); // reflect ASSERT_EQ(Status::OK(), status); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, relu_1) { NDArray input('c', {1,5,5,6}, { 0.557449, 0.768277, 1.094015, -0.557449, -0.768277, -1.094015,0.563735, 0.900299, 0.789979, -0.563735, -0.900299, -0.789979, 0.142528, 0.959611, 0.877506, -0.142528, -0.959611, -0.877506,0.448742, 0.995377, 1.171543, -0.448742, -0.995377, -1.171543, 0.603772, 0.799391, 0.560310, -0.603772, -0.799391, -0.560310,0.529753, 0.906786, 0.737630, -0.529753, -0.906786, -0.737630, 0.221464, 0.824996, 0.472221, -0.221464, -0.824996, -0.472221,0.427730, 0.397933, 0.714365, -0.427730, -0.397933, -0.714365, 0.488365, 1.016589, 0.744197, -0.488365, -1.016589, -0.744197,0.789846, 0.940837, 0.838412, -0.789846, -0.940837, -0.838412, 0.404485, 0.677328, 0.754997, -0.404485, -0.677328, -0.754997,0.436760, 0.794765, 0.729766, -0.436760, -0.794765, -0.729766, 0.588081, 0.652226, 0.725522, -0.588081, -0.652226, -0.725522,0.374457, 1.225813, 1.053411, -0.374457, -1.225813, -1.053411, 0.300958, 0.599417, 0.633234, -0.300958, -0.599417, -0.633234,0.241993, 1.025464, 0.695378, -0.241993, -1.025464, -0.695378, 0.236289, 0.907919, 1.012100, -0.236289, -0.907919, -1.012100,0.627402, 0.565187, 0.766926, -0.627402, -0.565187, -0.766926, 0.133276, 0.326284, 0.102804, -0.133276, -0.326284, -0.102804,0.426913, 0.256251, 0.305241, -0.426913, -0.256251, -0.305241, 0.177977, 0.841799, 0.800615, -0.177977, -0.841799, -0.800615,0.001991, 0.518389, 0.439322, -0.001991, -0.518389, -0.439322, 0.166846, 0.508224, 0.486687, -0.166846, -0.508224, -0.486687,0.167493, 0.930932, 0.868717, -0.167493, -0.930932, -0.868717, 0.174864, 0.444607, 0.445000, -0.174864, -0.444607, -0.445000}, nd4j::DataType::FLOAT32); NDArray expected('c', {1,5,5,6}, { 0.557449, 0.768277, 1.094015, 0., 0., 0., 0.563735, 0.900299, 0.789979, 0., 0., 0., 0.142528, 0.959611, 0.877506, 0., 0., 0., 0.448742, 0.995377, 1.171543, 0., 0., 0., 0.603772, 0.799391, 0.560310, 0., 0., 0., 0.529753, 0.906786, 0.737630, 0., 0., 0., 0.221464, 0.824996, 0.472221, 0., 0., 0., 0.427730, 0.397933, 0.714365, 0., 0., 0., 0.488365, 1.016589, 0.744197, 0., 0., 0., 0.789846, 0.940837, 0.838412, 0., 0., 0., 0.404485, 0.677328, 0.754997, 0., 0., 0., 0.436760, 0.794765, 0.729766, 0., 0., 0., 0.588081, 0.652226, 0.725522, 0., 0., 0., 0.374457, 1.225813, 1.053411, 0., 0., 0., 0.300958, 0.599417, 0.633234, 0., 0., 0., 0.241993, 1.025464, 0.695378, 0., 0., 0., 0.236289, 0.907919, 1.012100, 0., 0., 0., 0.627402, 0.565187, 0.766926, 0., 0., 0., 0.133276, 0.326284, 0.102804, 0., 0., 0., 0.426913, 0.256251, 0.305241, 0., 0., 0., 0.177977, 0.841799, 0.800615, 0., 0., 0., 0.001991, 0.518389, 0.439322, 0., 0., 0., 0.166846, 0.508224, 0.486687, 0., 0., 0., 0.167493, 0.930932, 0.868717, 0., 0., 0., 0.174864, 0.444607, 0.445000, 0., 0., 0.}, nd4j::DataType::FLOAT32); NDArray z('c', {1,5,5,6}, nd4j::DataType::FLOAT32); nd4j::ops::relu op; Nd4jStatus status = op.execute({&input}, {&z}, {0}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_TRUE(expected.isSameShapeStrict(&z)); ASSERT_TRUE(expected.equalsTo(z)); } #include "ops/declarable/helpers/multiUnique.h" //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, multiUnique_1) { NDArray input1('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, nd4j::DataType::INT32); NDArray input2('c', {3,4}, {1,2,3,4,5,6,7,8,9,10,11,12}, nd4j::DataType::INT32); NDArray input3('c', {2,3}, {10,11,12,13,14,15}, nd4j::DataType::INT32); NDArray input4('c', {1,5}, {7,8,9,10,11}, nd4j::DataType::INT32); NDArray input5('c', {5,3}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, nd4j::DataType::INT32); //NDArray indices('c', {1}, {2}, nd4j::DataType::INT32); //NDArray expected('c', {1,5}, {11, 12, 13, 14, 15.}, nd4j::DataType::FLOAT32); std::vector arrayList({&input1, &input2, &input3, &input4, &input5}); ASSERT_FALSE(nd4j::ops::helpers::multiUnique(arrayList)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, multiUnique_2) { NDArray input1('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, nd4j::DataType::INT32); NDArray input2('c', {3,4}, {21,22,23,24,25,26,27,28,29,210,211,212}, nd4j::DataType::INT32); NDArray input3('c', {2,3}, {310,311,312,313,314,315}, nd4j::DataType::INT32); NDArray input4('c', {1,5}, {47,48,49,410,411}, nd4j::DataType::INT32); NDArray input5('c', {5,3}, {51,52,53,54,55,56,57,58,59,510,511,512,513,514,515}, nd4j::DataType::INT32); //NDArray indices('c', {1}, {2}, nd4j::DataType::INT32); //NDArray expected('c', {1,5}, {11, 12, 13, 14, 15.}, nd4j::DataType::FLOAT32); std::vector arrayList({&input1, &input2, &input3, &input4, &input5}); ASSERT_TRUE(nd4j::ops::helpers::multiUnique(arrayList)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, tensormmul_6) { NDArray x('c', {1}, {2}); NDArray y('c', {2,1,2}, {1,2,3,4}); NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32); nd4j::ops::tensormmul op; auto results = op.execute({&x, &y}, {}, {1,0, 1,1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto *result = results->at(0); // exp.printShapeInfo(); // result->printShapeInfo(); // result->printIndexedBuffer(); ASSERT_TRUE(exp.isSameShape(result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, concat_test10) { NDArray x0('c', {1,4,5}, nd4j::DataType::FLOAT32); NDArray x1('c', {2,4,5}, nd4j::DataType::FLOAT32); NDArray z('f', {3,4,5}, nd4j::DataType::FLOAT32); x0 = 0.; x1 = 1.; nd4j::ops::concat op; auto status = op.execute({&x0, &x1}, {&z}, {}, {0}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, concat_14) { NDArray x0('c', {1,6}, {1,2,3,4,5,6}); NDArray x1('c', {1,6}, {7,8,9,10,11,12}); NDArray output('f', {2,6}, nd4j::DataType::DOUBLE); NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12}); nd4j::ops::concat op; auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); // output.printBuffer(); // output.printIndexedBuffer(); ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, concat_15) { NDArray x0('c', {1,4}, {1,2,3,4}); NDArray x1('c', {1,4}, {5,6,7,8}); NDArray output('c', {2,4}, nd4j::DataType::DOUBLE); NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8}); nd4j::ops::concat op; auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); // output.printBuffer(); // output.printIndexedBuffer(); ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, reduceMeanBp_4) { NDArray x('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}); NDArray gradO('c', {5}, nd4j::DataType::DOUBLE); NDArray exp('c', {3,5}, nd4j::DataType::DOUBLE); gradO = 1.; exp = 0.333333; nd4j::ops::reduce_mean_bp op; auto result = op.execute({&x, &gradO}, {}, {0}); auto output = result->at(0); // output->printShapeInfo(); // output->printIndexedBuffer(); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, reduceMeanBp_5) { NDArray x('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}); NDArray gradO('c', {3}, nd4j::DataType::DOUBLE); NDArray exp('c', {3,5}, nd4j::DataType::DOUBLE); gradO = 1.; exp = 0.2; nd4j::ops::reduce_mean_bp op; auto result = op.execute({&x, &gradO}, {}, {1}); auto output = result->at(0); // output->printShapeInfo(); // output->printIndexedBuffer(); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, reduceSqnormBp_1) { NDArray x('c', {8,6,4}, nd4j::DataType::DOUBLE); NDArray gradO('c', {8,6,1}, nd4j::DataType::DOUBLE); nd4j::ops::reduce_sqnorm_bp op; auto result = op.execute({&x, &gradO}, {1}, {2}); ASSERT_EQ(Status::OK(), result->status()); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cumsum_1) { NDArray x('f', {3, 4}, nd4j::DataType::FLOAT32); nd4j::ops::cumsum op; auto result = op.execute({&x}, {}, {0, 0, 1}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); // z->printShapeInfo(); // x.printShapeInfo(); ASSERT_TRUE(z->ews() == 1); ASSERT_TRUE(x.ews() == 1); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pullRows_1) { NDArray x('c', {5, 1}, {0,1,2,3,4}); NDArray z('c', {4, 1}, nd4j::DataType::DOUBLE); NDArray exp('c', {4, 1}, {0,2,3,4}); Nd4jLong indexes[] = {0,2,3,4}; PointersManager pm(LaunchContext::defaultContext(), "pullRows"); auto pidx = reinterpret_cast(pm.replicatePointer(indexes, 4 * sizeof(Nd4jLong))); std::vector dims = {1}; auto xTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), dims); auto zTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.getShapeInfo(), dims); NativeOps op; Nd4jPointer nativeStart[2]; #ifdef __CUDABLAS__ nativeStart[1] = *(x.getContext()->getCudaStream()); #endif op.pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), 4, pidx, xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); ASSERT_TRUE(z.equalsTo(exp)); pm.synchronize(); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pullRows_2) { NDArray arr('f', {5, 2}, {0,1,2,3,4,5,6,7,8,9}); NDArray* y = arr.dup('c'); NDArray x = (*y)({0,0, 0,1}, true); // view, points on first column of y, shape is {5,1} NDArray z('c', {4, 1}, nd4j::DataType::DOUBLE); NDArray exp('c', {4, 1}, {0,2,3,4}); Nd4jLong indexes[] = {0,2,3,4}; PointersManager pm(LaunchContext::defaultContext(), "pullRows"); auto pidx = reinterpret_cast(pm.replicatePointer(indexes, 4 * sizeof(Nd4jLong))); std::vector dims = {1}; auto xTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), dims); auto zTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.getShapeInfo(), dims); NativeOps op; Nd4jPointer nativeStart[2]; #ifdef __CUDABLAS__ nativeStart[1] = *(x.getContext()->getCudaStream()); #endif op.pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(), z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), 4, pidx, xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); ASSERT_TRUE(z.equalsTo(exp)); pm.synchronize(); delete y; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, softmax_9) { NDArray arrC('c', {5,2}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 1}, nd4j::DataType::FLOAT32); NDArray* arrF = arrC.dup('f'); NDArray outCC('c', {5,2}, nd4j::DataType::FLOAT32); NDArray outCF('f', {5,2}, nd4j::DataType::FLOAT32); NDArray outFC('c', {5,2}, nd4j::DataType::FLOAT32); NDArray outFF('c', {5,2}, nd4j::DataType::FLOAT32); nd4j::ops::softmax op; auto status1 = op.execute({&arrC}, {&outCC}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, status1); auto status2 = op.execute({&arrC}, {&outCF}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, status2); auto status3 = op.execute({arrF}, {&outFC}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, status3); auto status4 = op.execute({arrF}, {&outFF}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, status4); // outCC.printIndexedBuffer("\n"); // outCF.printIndexedBuffer("\n"); // outFC.printIndexedBuffer("\n"); // outFF.printIndexedBuffer("\n"); ASSERT_EQ(outCC, outCF); ASSERT_EQ(outCC, outFC); ASSERT_EQ(outCC, outFF); delete arrF; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_1) { NDArray input('c', {2,3,4,10}); NDArray gradO('c', {2,3,4,10}); NDArray exp('c', {2,3,4,10}, {1.00438418e-02, 5.25184907e-03, 1.78685773e-03, -1.14537543e-03, -4.00071684e-03, -5.31899510e-03, -4.97647980e-03, -4.42161644e-03, -3.95395281e-03, -3.59310722e-03, 2.91823584e-04, -2.18498681e-05, -3.12092161e-04, -6.07360795e-04, -9.36298165e-04, -1.02553482e-03, -7.91735307e-04, -6.15672267e-04, -4.71792649e-04, -3.42114770e-04, 4.29357824e-05, -5.46473675e-05, -1.48361753e-04, -2.47166492e-04, -3.61090642e-04, -3.81607766e-04, -2.89086485e-04, -2.17203109e-04, -1.56231865e-04, -9.91634734e-05, 8.99407951e-06, -3.76849275e-05, -8.32021178e-05, -1.31939698e-04, -1.89008832e-04, -1.96661276e-04, -1.47534331e-04, -1.08789405e-04, -7.53896020e-05, -4.36357586e-05, 1.23124300e-06, -2.60028974e-05, -5.27824741e-05, -8.17063192e-05, -1.15871291e-04, -1.19515295e-04, -8.91248055e-05, -6.49499125e-05, -4.39216528e-05, -2.37579407e-05, -9.34046056e-07, -1.87477999e-05, -3.63574763e-05, -5.54830040e-05, -7.82010393e-05, -8.02115537e-05, -5.95739621e-05, -4.30659420e-05, -2.86241393e-05, -1.47010251e-05, -1.52835810e-06, -1.40790498e-05, -2.65316012e-05, -4.01083526e-05, -5.62983550e-05, -5.75223821e-05, -4.25982689e-05, -3.06141737e-05, -2.00884024e-05, -9.90276021e-06, -1.61666367e-06, -1.09328157e-05, -2.02010433e-05, -3.03347279e-05, -4.24536738e-05, -4.32532870e-05, -3.19610226e-05, -2.28673853e-05, -1.48570880e-05, -7.08444895e-06, -1.53552355e-06, -8.72318924e-06, -1.58886232e-05, -2.37402273e-05, -3.31507035e-05, -3.37014644e-05, -2.48602537e-05, -1.77248403e-05, -1.14254890e-05, -5.30027773e-06, -1.40318230e-06, -7.11624580e-06, -1.28209140e-05, -1.90826468e-05, -2.66006646e-05, -2.69959855e-05, -1.98865000e-05, -1.41387427e-05, -9.05554589e-06, -4.10473058e-06, -1.26330860e-06, -5.91293519e-06, -1.05618501e-05, -1.56718652e-05, -2.18157675e-05, -2.21090413e-05, -1.62681827e-05, -1.15394150e-05, -7.35144840e-06, -3.26711961e-06, -1.13179840e-06, -4.98940426e-06, -8.85062400e-06, -1.30997241e-05, -1.82144904e-05, -1.84380206e-05, -1.35542105e-05, -9.59566933e-06, -6.08572736e-06, -2.65887866e-06, -1.01367493e-06, -4.26561428e-06, -7.52358210e-06, -1.11123145e-05, -1.54364170e-05, -1.56106762e-05, -1.14666063e-05, -8.10436813e-06, -5.12021325e-06, -2.20401580e-06, -9.09635219e-07, -3.68808492e-06, -6.47385696e-06, -9.54499774e-06, -1.32485484e-05, -1.33870126e-05, -9.82651000e-06, -6.93532820e-06, -4.36710525e-06, -1.85539375e-06, -8.18735487e-07, -3.22003825e-06, -5.62928972e-06, -8.28724023e-06, -1.14948289e-05, -1.16066676e-05, -8.51461300e-06, -6.00201292e-06, -3.76846447e-06, -1.58258263e-06, -7.39498375e-07, -2.83553072e-06, -4.93973403e-06, -7.26259532e-06, -1.00675643e-05, -1.01591886e-05, -7.44886802e-06, -5.24508141e-06, -3.28481428e-06, -1.36524977e-06, -6.70378654e-07, -2.51585061e-06, -4.36947221e-06, -6.41683391e-06, -8.89049170e-06, -8.96649362e-06, -6.57134478e-06, -4.62275193e-06, -2.88851857e-06, -1.18941352e-06, -6.09944266e-07, -2.24723408e-06, -3.89250545e-06, -5.71062310e-06, -7.90838203e-06, -7.97212033e-06, -5.84020108e-06, -4.10491293e-06, -2.55976192e-06, -1.04521314e-06, -5.56935277e-07, -2.01937837e-06, -3.48954882e-06, -5.11487451e-06, -7.08044308e-06, -7.13442114e-06, -5.22460778e-06, -3.66942504e-06, -2.28403951e-06, -9.25535005e-07, -5.10270809e-07, -1.82444705e-06, -3.14605040e-06, -4.60769843e-06, -6.37601988e-06, -6.42213308e-06, -4.70144141e-06, -3.29971408e-06, -2.05053857e-06, -8.25151346e-07, -4.69036365e-07, -1.65639949e-06, -2.85086708e-06, -4.17237243e-06, -5.77171340e-06, -5.81141694e-06, -4.25308644e-06, -2.98317354e-06, -1.85106614e-06, -7.40148607e-07, -4.32460268e-07, -1.51051631e-06, -2.59534818e-06, -3.79594053e-06, -5.24941379e-06, -5.28384317e-06, -3.86593183e-06, -2.71007866e-06, -1.67932183e-06, -6.67554332e-07, -3.99893480e-07, -1.38306928e-06, -2.37269478e-06, -3.46823890e-06, -4.79492701e-06, -4.82497671e-06, -3.52932648e-06, -2.47282924e-06, -1.53039912e-06, -6.05077048e-07, -3.70789934e-07, -1.27108103e-06, -2.17750403e-06, -3.18120783e-06, -4.39700398e-06, -4.42338614e-06, -3.23483960e-06, -2.26541715e-06, -1.40042869e-06, -5.50929371e-07}); input.linspace(1); gradO = 1; nd4j::ops::lrn_bp op; auto results = op.execute({&input, &gradO}, {1., 1., 1}, {5}); auto gradI = results->at(0); ASSERT_EQ(*gradI, exp); delete results; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_2) { NDArray input('c', {2,3,4,10}); NDArray gradO('c', {2,3,4,10}); NDArray exp('c', {2,3,4,10}, {-1.06179598e-03, -2.70050880e-03, -4.02126182e-03, -2.58826977e-03, -2.16024881e-03, -2.20575323e-03, -2.75954953e-03, -4.42477595e-03, -2.89176637e-03, -9.46942251e-04, -1.32603094e-03, -3.34868953e-03, -4.98152524e-03, -3.21313459e-03, -2.68880837e-03, -2.75207381e-03, -3.45109636e-03, -5.54159656e-03, -3.61320702e-03, -1.16457068e-03, -1.70158676e-03, -4.26037982e-03, -6.33032294e-03, -4.09416296e-03, -3.43742501e-03, -3.52900685e-03, -4.43827361e-03, -7.13911094e-03, -4.64041065e-03, -1.46419462e-03, -2.26016506e-03, -5.59943309e-03, -8.30824208e-03, -5.39253885e-03, -4.54709725e-03, -4.68666852e-03, -5.91615774e-03, -9.53640230e-03, -6.17204653e-03, -1.89000927e-03, -3.14102764e-03, -7.67878769e-03, -1.13740638e-02, -7.41857197e-03, -6.29213545e-03, -6.51977258e-03, -8.27047508e-03, -1.33656031e-02, -8.59564263e-03, -2.51553906e-03, -4.64272872e-03, -1.11560747e-02, -1.64905936e-02, -1.08321551e-02, -9.26420093e-03, -9.67171416e-03, -1.23506878e-02, -2.00199075e-02, -1.27442302e-02, -3.45497206e-03, -7.49545777e-03, -1.76018942e-02, -2.59558801e-02, -1.72390267e-02, -1.49321631e-02, -1.57669969e-02, -2.03234926e-02, -3.30405571e-02, -2.06389092e-02, -4.78462130e-03, -1.38390735e-02, -3.14943902e-02, -4.63354364e-02, -3.13667879e-02, -2.77508944e-02, -2.98541505e-02, -3.89749333e-02, -6.32867143e-02, -3.77952419e-02, -5.26650995e-03, -3.16195861e-02, -6.90807998e-02, -1.01725549e-01, -7.13700354e-02, -6.54785037e-02, -7.25797564e-02, -9.49372798e-02, -1.47399038e-01, -7.21285641e-02, 2.15010419e-02, -8.06625858e-02, -1.79638922e-01, -2.66877055e-01, -1.64447501e-01, -1.00968637e-01, -2.75682062e-02, 1.13596700e-01, 3.32260162e-01, 5.96845448e-01, 8.13161016e-01, 9.52381015e-01, 8.13161016e-01, 5.96845508e-01, 3.32260162e-01, 1.13596708e-01, -2.75682174e-02, -1.37202948e-01, -2.71326721e-01, -1.84127048e-01, -7.94974267e-02, 3.29870060e-02, -7.39035010e-02, -1.60488203e-01, -1.04997143e-01, -8.06594491e-02, -7.25797564e-02, -7.87955597e-02, -1.11791104e-01, -7.58660138e-02, -3.48676592e-02, -4.96974029e-03, -4.04525958e-02, -6.82792515e-02, -4.20900472e-02, -3.21968049e-02, -2.98541524e-02, -3.36477235e-02, -4.95737195e-02, -3.37007530e-02, -1.48636252e-02, -4.92655952e-03, -2.17927732e-02, -3.49853337e-02, -2.15152260e-02, -1.66727621e-02, -1.57669988e-02, -1.81730352e-02, -2.73226351e-02, -1.85334161e-02, -7.91355036e-03, -3.57114570e-03, -1.33136865e-02, -2.09431648e-02, -1.29161589e-02, -1.01064872e-02, -9.67171136e-03, -1.12970043e-02, -1.71830691e-02, -1.16271935e-02, -4.84848116e-03, -2.59314431e-03, -8.91274121e-03, -1.38697922e-02, -8.58002994e-03, -6.75992295e-03, -6.51977304e-03, -7.68158771e-03, -1.17703741e-02, -7.94785097e-03, -3.25604435e-03, -1.94202550e-03, -6.36530807e-03, -9.84015409e-03, -6.10316684e-03, -4.83274320e-03, -4.68666898e-03, -5.55526093e-03, -8.55536573e-03, -5.76688722e-03, -2.33053416e-03, -1.50016253e-03, -4.76644421e-03, -7.33569637e-03, -4.55961144e-03, -3.62428720e-03, -3.52900638e-03, -4.20164689e-03, -6.49448857e-03, -4.37143166e-03, -1.74761284e-03, -1.19028054e-03, -3.69978836e-03, -5.67591935e-03, -3.53418733e-03, -2.81759514e-03, -2.75207404e-03, -3.28776496e-03, -5.09600528e-03, -3.42601724e-03, -1.35771628e-03, -9.65878542e-04, -2.95373448e-03, -4.52052988e-03, -2.81889434e-03, -2.25270819e-03, -2.20575323e-03, -2.64216494e-03, -4.10421193e-03, -2.75646802e-03, -1.08450721e-03, -7.98697409e-04, -2.41194153e-03, -3.68447183e-03, -2.30037421e-03, -1.84193184e-03, -1.80714857e-03, -2.16938392e-03, -3.37567786e-03, -2.26523401e-03, -8.85842834e-04, -6.71049987e-04, -2.00629188e-03, -3.06024216e-03, -1.91263494e-03, -1.53396139e-03, -1.50748459e-03, -1.81288645e-03, -2.82496959e-03, -1.89429161e-03, -7.36965681e-04, -5.71501616e-04, -1.69480499e-03, -2.58198148e-03, -1.61517004e-03, -1.29717519e-03, -1.27655920e-03, -1.53747783e-03, -2.39865575e-03, -1.60740130e-03, -6.22576685e-04, -4.92433901e-04, -1.45049067e-03, -2.20754091e-03, -1.38200901e-03, -1.11122860e-03, -1.09486456e-03, -1.32032647e-03, -2.06194492e-03, -1.38099224e-03, -5.32818493e-04}); input.linspace(-10, 0.1); gradO = 1; nd4j::ops::lrn_bp op; auto results = op.execute({&input, &gradO}, {1., 1., 1}, {2}); auto gradI = results->at(0); ASSERT_EQ(*gradI, exp); delete results; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_3) { NDArray input('c', {2,3,4,10}); NDArray gradO('c', {2,3,4,10}); NDArray exp('c', {2,3,4,10}, {-6.78180193e-04, -1.06947345e-03, -1.50362519e-03, -1.47711602e-03, -1.45060697e-03, -1.42409769e-03, -1.39758852e-03, -1.37107936e-03, -8.79839936e-04, -4.27795108e-04, -8.62496032e-04, -1.34585891e-03, -1.88281795e-03, -1.84591592e-03, -1.80901436e-03, -1.77211256e-03, -1.73521065e-03, -1.69830909e-03, -1.08184782e-03, -5.13895764e-04, -1.13227055e-03, -1.74428569e-03, -2.42520543e-03, -2.37169350e-03, -2.31818156e-03, -2.26466986e-03, -2.21115816e-03, -2.15764646e-03, -1.36136822e-03, -6.26647263e-04, -1.54878304e-03, -2.34815548e-03, -3.23930010e-03, -3.15753091e-03, -3.07576265e-03, -2.99399323e-03, -2.91222427e-03, -2.83045508e-03, -1.76287338e-03, -7.75904860e-04, -2.23870482e-03, -3.32566188e-03, -4.54067392e-03, -4.40674182e-03, -4.27281018e-03, -4.13887901e-03, -4.00494691e-03, -3.87101574e-03, -2.36659218e-03, -9.72117065e-04, -3.49745504e-03, -5.05724549e-03, -6.80746930e-03, -6.56589260e-03, -6.32431870e-03, -6.08274434e-03, -5.84116904e-03, -5.59959421e-03, -3.32604628e-03, -1.21081201e-03, -6.14068285e-03, -8.55270587e-03, -1.12749329e-02, -1.07723922e-02, -1.02698486e-02, -9.76730697e-03, -9.26476624e-03, -8.76222178e-03, -4.94601438e-03, -1.37539487e-03, -1.30690653e-02, -1.72132626e-02, -2.19351258e-02, -2.06174850e-02, -1.92998387e-02, -1.79821979e-02, -1.66645572e-02, -1.53469117e-02, -7.72346184e-03, -5.22134826e-04, -3.99478227e-02, -4.78655733e-02, -5.70126995e-02, -5.16961850e-02, -4.63796593e-02, -4.10631336e-02, -3.57466117e-02, -3.04300785e-02, -9.11374856e-03, 1.14024431e-02, -2.35893592e-01, -2.17480078e-01, -1.88097835e-01, -1.38812393e-01, -8.95269737e-02, -4.02415469e-02, 9.04385652e-03, 5.83292767e-02, 1.78530529e-01, 2.96026409e-01, 4.16666657e-01, 2.79557735e-01, 1.36546940e-01, 7.49502778e-02, 1.33536234e-02, -4.82430384e-02, -1.09839723e-01, -1.71436355e-01, -2.33033031e-01, -2.74476141e-01, 1.54189002e-02, -8.10869783e-03, -3.24862264e-02, -3.88403721e-02, -4.51945364e-02, -5.15486896e-02, -5.79028539e-02, -6.42570183e-02, -5.45457527e-02, -4.61437553e-02, -2.29711179e-04, -8.06892477e-03, -1.63567103e-02, -1.78351123e-02, -1.93135180e-02, -2.07919199e-02, -2.22703181e-02, -2.37487257e-02, -1.87229179e-02, -1.43175106e-02, -1.37000845e-03, -5.16320160e-03, -9.21433326e-03, -9.76086594e-03, -1.03073996e-02, -1.08539313e-02, -1.14004640e-02, -1.19469995e-02, -9.08647850e-03, -6.55380823e-03, -1.23490533e-03, -3.45137389e-03, -5.83263952e-03, -6.09064987e-03, -6.34865928e-03, -6.60666777e-03, -6.86467718e-03, -7.12268520e-03, -5.30054048e-03, -3.67741752e-03, -9.94500006e-04, -2.44303374e-03, -4.00528917e-03, -4.14666394e-03, -4.28803731e-03, -4.42941114e-03, -4.57078544e-03, -4.71215881e-03, -3.45545518e-03, -2.33156094e-03, -7.93270417e-04, -1.81236281e-03, -2.91444198e-03, -3.00004939e-03, -3.08565609e-03, -3.17126350e-03, -3.25687067e-03, -3.34247784e-03, -2.42513884e-03, -1.60246110e-03, -6.39747130e-04, -1.39506557e-03, -2.21352675e-03, -2.26921216e-03, -2.32489733e-03, -2.38058274e-03, -2.43626791e-03, -2.49195332e-03, -1.79354590e-03, -1.16592250e-03, -5.23828785e-04, -1.10576022e-03, -1.73730974e-03, -1.77553250e-03, -1.81375467e-03, -1.85197743e-03, -1.89020019e-03, -1.92842260e-03, -1.37922564e-03, -8.84913374e-04, -4.35433642e-04, -8.97393096e-04, -1.39935245e-03, -1.42670958e-03, -1.45406683e-03, -1.48142409e-03, -1.50878134e-03, -1.53613824e-03, -1.09309505e-03, -6.93831593e-04, -3.66991735e-04, -7.42538832e-04, -1.15100679e-03, -1.17125409e-03, -1.19150116e-03, -1.21174823e-03, -1.23199564e-03, -1.25224248e-03, -8.87364266e-04, -5.58210537e-04, -3.13144788e-04, -6.24410110e-04, -9.63238359e-04, -9.78639582e-04, -9.94040747e-04, -1.00944215e-03, -1.02484343e-03, -1.04024459e-03, -7.34565372e-04, -4.58585098e-04, -2.70129647e-04, -5.32291830e-04, -8.17865424e-04, -8.29851197e-04, -8.41836852e-04, -8.53822567e-04, -8.65808397e-04, -8.77794111e-04, -6.18013146e-04, -3.83307983e-04, -2.35282409e-04, -4.59096394e-04, -7.03040219e-04, -7.12549896e-04, -7.22059398e-04, -7.31569016e-04, -7.41078693e-04, -7.50588137e-04, -5.27105702e-04, -3.25074652e-04}); input.linspace(-10, 0.1); gradO = 1; nd4j::ops::lrn_bp op; auto results = op.execute({&input, &gradO}, {1., 1., 1}, {7}); auto gradI = results->at(0); ASSERT_EQ(*gradI, exp); delete results; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_4) { NDArray input('c', {2,3,4,10}); NDArray gradO('c', {2,3,4,10}); NDArray exp('c', {2,3,4,10}, {-0.00119282, -0.00116995, -0.00114708, -0.00112421, -0.00110134, -0.00107847, -0.00105559, -0.00103272, -0.00100985, -0.00098698, -0.00150102, -0.00146918, -0.00143734, -0.0014055 , -0.00137366, -0.00134182, -0.00130998, -0.00127814, -0.0012463 , -0.00121446, -0.00194534,-0.00189916, -0.00185299, -0.00180681, -0.00176064, -0.00171446, -0.00166829, -0.00162211, -0.00157593, -0.00152976, -0.0026189 , -0.00254833, -0.00247776, -0.00240719, -0.00233662, -0.00226605, -0.00219548, -0.00212491, -0.00205434, -0.00198377, -0.00370962, -0.00359401, -0.00347839, -0.00336277, -0.00324716, -0.00313154, -0.00301593, -0.00290031, -0.00278469, -0.00266908, -0.00564327, -0.00543464, -0.00522602, -0.00501739, -0.00480876, -0.00460013, -0.0043915 , -0.00418288, -0.00397425, -0.00376562, -0.00955302, -0.00911865, -0.00868428, -0.00824992, -0.00781555, -0.00738118, -0.00694682, -0.00651245, -0.00607808, -0.00564371, -0.01927758, -0.01813637, -0.01699515, -0.01585394, -0.01471272, -0.01357151, -0.01243029, -0.01128908, -0.01014786, -0.00900664, -0.05409876, -0.04945958, -0.04482041, -0.04018124, -0.03554206, -0.03090289, -0.02626371, -0.02162454, -0.01698537, -0.01234619, -0.26145172, -0.214688 , -0.16792431, -0.12116055, -0.07439683, -0.02763309, 0.01913062, 0.06589434, 0.11265809, 0.15942183, 0.25974026, 0.19902176, 0.13830325, 0.07758474, 0.01686624, -0.04385226, -0.10457078, -0.16528927, -0.22600779, -0.2867263 , -0.01177884, -0.0173331 , -0.02288735, -0.02844159, -0.03399584, -0.0395501 , -0.04510435, -0.05065861, -0.05621284, -0.0617671 , -0.00944993, -0.01073084, -0.01201174, -0.01329265, -0.01457355, -0.01585446, -0.01713536, -0.01841627, -0.01969717, -0.02097807, -0.00589878, -0.00637122, -0.00684368, -0.00731612, -0.00778858, -0.00826102, -0.00873347, -0.00920592, -0.00967837, -0.01015082, -0.00390961, -0.00413245, -0.00435528, -0.00457812, -0.00480095, -0.00502378, -0.00524662, -0.00546945, -0.00569229, -0.00591512, -0.00275609, -0.00287813, -0.00300018, -0.00312222, -0.00324427, -0.00336631, -0.00348836, -0.0036104 , -0.00373245, -0.00385449, -0.00203982, -0.00211371, -0.00218759, -0.00226147, -0.00233536, -0.00240924, -0.00248312, -0.00255701, -0.00263089, -0.00270478, -0.00156781, -0.00161586, -0.00166391, -0.00171197, -0.00176002, -0.00180807, -0.00185612, -0.00190417, -0.00195223, -0.00200028, -0.00124141, -0.00127439, -0.00130737, -0.00134035, -0.00137333, -0.00140631, -0.00143929, -0.00147227, -0.00150525, -0.00153822, -0.00100674, -0.00103034, -0.00105394, -0.00107754, -0.00110115, -0.00112475, -0.00114835, -0.00117195, -0.00119556, -0.00121916, -0.00083255, -0.00085002, -0.00086748, -0.00088495, -0.00090242, -0.00091989, -0.00093735, -0.00095482, -0.00097229, -0.00098976, -0.0006998 , -0.00071308, -0.00072637, -0.00073965, -0.00075294, -0.00076623, -0.00077951, -0.0007928 , -0.00080609, -0.00081937, -0.00059635, -0.00060669, -0.00061703, -0.00062737, -0.00063771, -0.00064805, -0.00065839, -0.00066873, -0.00067906, -0.0006894 , -0.0005142 , -0.0005224 , -0.00053061, -0.00053881, -0.00054701, -0.00055522, -0.00056342, -0.00057162, -0.00057983, -0.00058803}); input.linspace(-10, 0.1); gradO = 1; nd4j::ops::lrn_bp op; auto results = op.execute({&input, &gradO}, {1., 1., 1}, {12}); auto gradI = results->at(0); ASSERT_EQ(*gradI, exp); delete results; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_5) { NDArray input('c', {2,2,2,5}); NDArray gradO('c', {2,2,2,5}); NDArray exp('c', {2,2,2,5}, {6.2497472e-03, -3.4008762e-03, -1.5232352e-02, 2.3018382e-04, 1.3257053e-02, 7.1492628e-03, -5.4330104e-03, -2.0878183e-02, 1.5153568e-03, 2.0571884e-02, 6.7926152e-03, -1.0990440e-02, -3.2685306e-02, 7.2436016e-03, 4.2120241e-02, -1.3439789e-02, -3.4284033e-02, -4.4852167e-02, 8.8073254e-02, 2.2223940e-01, 4.0824831e-01, 2.1201703e-01, 3.8555145e-02, -3.1969927e-02, -3.0673094e-02, 5.2034661e-02, 1.0463811e-02, -3.6619946e-02, -1.3280880e-02, 5.9767403e-03, 2.3028374e-02, 2.0452859e-03, -2.2533152e-02, -6.1039329e-03, 7.2805062e-03, 1.4290780e-02, 3.8017845e-04, -1.6107092e-02,-3.6896234e-03, 6.4357026e-03}); input.linspace(-20, 1); // gradO.linspace(0.1, 0.1); gradO = 1; nd4j::ops::lrn_bp op; auto results = op.execute({&input, &gradO}, {1., 1., 0.5}, {2}); auto gradI = results->at(0); ASSERT_EQ(*gradI, exp); delete results; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_6) { NDArray input('c', {1,1,1,5}, {1, 2., 3, 4, 5}); NDArray gradO('c', {1,1,1,5}); NDArray exp('c', {1,1,1,5}, {0.06926288, 0.04360996, 0.01795704, -0.00769587, -0.0333488}); // gradO.linspace(-1.5, 0.1); gradO = 1; nd4j::ops::lrn_bp op; auto results = op.execute({&input, &gradO}, {1., 2., 0.5}, {10}); auto gradI = results->at(0); ASSERT_EQ(*gradI, exp); delete results; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_7) { NDArray input('c', {2,2,2,5}); NDArray gradO('c', {2,2,2,5}); input.linspace(-20, 1); gradO.linspace(-1.5, 0.1); const OpArgsHolder argsHolderFF({&input}, {1,2,0.5}, {2}); const OpArgsHolder argsHolderBP({&input, &gradO}, {1,2,0.5}, {2}); nd4j::ops::lrn opFF; nd4j::ops::lrn_bp opBP; const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_8) { NDArray input('c', {1,1,1,5}, {1, 2, 3, 4, 5}); NDArray gradO('c', {1,1,1,5}, {2, 3, 4, 5, 6}); const OpArgsHolder argsHolderFF({&input}, {1,2,0.5}, {2}); const OpArgsHolder argsHolderBP({&input, &gradO}, {1,2,0.5}, {2}); nd4j::ops::lrn opFF; nd4j::ops::lrn_bp opBP; const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_9) { NDArray input('c', {1,1,1,5}, {1,2,3,4,5}); NDArray gradO('c', {1,1,1,5}, {1, 1, 1, 1, 1}); NDArray exp('c', {1,1,1,5}, {0.1084472 , 0.03816165, 0.00978456, -0.01859251,-0.02511311}); nd4j::ops::lrn_bp op; auto results = op.execute({&input, &gradO}, {1., 2., 0.5}, {3}); auto gradI = results->at(0); // for (int i = 0; i < exp.lengthOf(); ++i) // printf("%10.5f %10.5f\n", exp.e(i), gradI->e(i)); ASSERT_EQ(*gradI, exp); delete results; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_10) { NDArray input('c', {1,1,1,1}, {1}); NDArray gradO('c', {1,1,1,1}, {1}); NDArray exp('c', {1,1,1,1}, {0.19245008}); nd4j::ops::lrn_bp op; auto results = op.execute({&input, &gradO}, {1., 2., 0.5}, {1}); auto gradI = results->at(0); ASSERT_EQ(*gradI, exp); delete results; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_1) { NDArray input('c', {2,2,2,5}); NDArray exp('c', {2,2,2,5}, {-0.42923987, -0.3623817 , -0.3152079 , -0.34268343, -0.3836809, -0.43648192, -0.3652726 , -0.31428117, -0.3379276 , -0.3731494 , -0.45129365, -0.37083852, -0.3111639 , -0.3260225 , -0.34698898, -0.4975186 , -0.3831305 , -0.2847474 , -0.25607377, -0.18569534, 0., 0.18569534, 0.25607377, 0.38411066, 0.52075565,0.33633637, 0.32117262, 0.30966178, 0.37259716, 0.45631808, 0.36986336, 0.33643705, 0.31394684, 0.36608824, 0.43857202, 0.3821113 , 0.34197718, 0.31508508, 0.36284128, 0.4303756 }); input.linspace(-20, 1); nd4j::ops::lrn op; auto results = op.execute({&input}, {1., 2., 0.5}, {2}); auto output = results->at(0); ASSERT_EQ(*output, exp); delete results; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_2) { NDArray input('c', {1,1,1,5}, {1, 2., 3, 4, 5}); NDArray exp('c', {1,1,1,5}, {0.09530295, 0.1906059 , 0.28590885, 0.3812118 , 0.47651473}); nd4j::ops::lrn op; auto results = op.execute({&input}, {0.1, 2., 0.5}, {5}); auto output = results->at(0); ASSERT_EQ(*output, exp); delete results; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_3) { NDArray input('c', {1,1,1,1}, {1.}); NDArray exp('c', {1,1,1,1}, {0.69006556}); nd4j::ops::lrn op; auto results = op.execute({&input}, {0.1, 2., 0.5}, {5}); auto output = results->at(0); ASSERT_EQ(*output, exp); delete results; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_4) { NDArray input('c', {1,1,1,1}, {1.}); NDArray exp('c', {1,1,1,1}, {0.69006556}); nd4j::ops::lrn op; auto results = op.execute({&input}, {0.1, 2., 0.5}, {0}); auto output = results->at(0); ASSERT_EQ(*output, exp); delete results; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_5) { NDArray input('c', {1,1,1,5}, {1, 2., 3, 4, 5}); NDArray exp('c', {1,1,1,5}, {0.69006556, 0.70272833, 0.7051508 , 0.7060045 , 0.7064008}); nd4j::ops::lrn op; auto results = op.execute({&input}, {0.1, 2., 0.5}, {0}); auto output = results->at(0); ASSERT_EQ(*output, exp); delete results; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, inTopK_1) { NDArray x('c', {4, 5}, {11.0, 14.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 5.0, 16.0, 9.0, 13.5, 7.0}); NDArray y('c', {4}, {0, 0, 0, 0}, nd4j::DataType::INT64); NDArray z('c', {4}, {1, 1, 1, 1}, nd4j::DataType::BOOL); NDArray expV('c', {4}, {1, 0, 0, 0}, nd4j::DataType::BOOL); nd4j::ops::in_top_k op; Nd4jStatus status = op.execute({&x, &y, }, {&z}, {}, {2}, {}); // z.printIndexedBuffer(); ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_TRUE(expV.isSameShape(z)); ASSERT_TRUE(expV.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cube_1) { NDArray x('c', {2, 3}, {1., 2., 3., 4., 5, 6}); NDArray exp('c', {2, 3}, {1., 8., 27., 64., 125, 216}); nd4j::ops::cube op; auto result = op.execute({&x}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); delete result; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cube_bp_1) { NDArray x('c', {2, 3}, {1., 2., 3., 4., 5, 6}); NDArray gradO('c', {2, 3}, nd4j::DataType::DOUBLE); NDArray exp('c', {2, 3}, {1.5, 6., 13.5, 24., 37.5, 54}); gradO = 0.5; nd4j::ops::cube_bp op; auto result = op.execute({&x, &gradO}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); // z->printIndexedBuffer(); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); delete result; } //////////////////////////////////////////////////////////////////// // CONSTANT mode 2D TEST_F(DeclarableOpsTests12, pad_tests1) { NDArray input('c', {2,3}, {1,2,3,4,5,6}, nd4j::DataType::FLOAT32); NDArray paddings('c', {2,2}, {1,1,2,2}, nd4j::DataType::INT32); NDArray expected('c', {4,7}, {0,0,0,0,0,0,0, 0,0,1,2,3,0,0, 0,0,4,5,6,0,0, 0,0,0,0,0,0,0}, nd4j::DataType::FLOAT32); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// // REFLECT mode 2D TEST_F(DeclarableOpsTests12, pad_tests2) { float inBuff[] = {1,2,3,4,5,6}; int padBuff[] = {1,1,2,2}; float expBuff[] = {6,5,4,5,6,5,4, 3,2,1,2,3,2,1, 6,5,4,5,6,5,4, 3,2,1,2,3,2,1}; auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// // SYMMETRIC mode 2D TEST_F(DeclarableOpsTests12, pad_tests3) { float inBuff[] = {1,2,3,4,5,6}; int padBuff[] = {1,1,2,2}; float expBuff[] = {2,1,1,2,3,3,2, 2,1,1,2,3,3,2, 5,4,4,5,6,6,5, 5,4,4,5,6,6,5}; auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// // CONSTANT mode 3D TEST_F(DeclarableOpsTests12, pad_tests4) { float inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; int padBuff[] = {1,1,2,2,2,2}; float expBuff[] = {0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 1, 2, 3,0,0,0,0, 4, 5, 6,0,0,0,0, 7, 8, 9,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0,10,11,12,0, 0,0,0,13,14,15,0,0,0,0,16,17,18,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0}; auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); // for(int i = 0; i < expected.lengthOf(); ++i) { // float one = expected.e(i); // float two = result->e(i); // if(one != two) // printf("%i : %f, %f\n", i, one, two); // } delete results; } //////////////////////////////////////////////////////////////////// // REFLECT mode 3D TEST_F(DeclarableOpsTests12, pad_tests5) { float inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; int padBuff[] = {1,1,2,2,2,2}; float expBuff[] = {18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}; auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// // SYMMETRIC mode 3D TEST_F(DeclarableOpsTests12, pad_tests6) { float inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; int padBuff[] = {1,1,2,2,2,2}; float expBuff[] = {5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14}; auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// // CONSTANT mode 4D TEST_F(DeclarableOpsTests12, pad_tests7) { float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; float expBuff[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 10, 0, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto *result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// // REFLECT mode 4D TEST_F(DeclarableOpsTests12, pad_tests8) { float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; float expBuff[] = {16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1}; auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto *result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } ////////////////////////////////////////////////////////////////// // SYMMETRIC mode 4D TEST_F(DeclarableOpsTests12, pad_tests9) { float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; float expBuff[] = {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16}; auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto *result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests10) { auto input = NDArrayFactory::create('c', {2,3,4}); auto paddings = NDArrayFactory::create('c', {3,2}, {0,0, 0,1, 0,0}); auto expected = NDArrayFactory::create('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); input = 1.f; //input.assign(1.); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests11) { auto input = NDArrayFactory::create('c', {2,3,4}); auto paddings = NDArrayFactory::create('c', {3,2}, {0,0, 0,1, 0,0}); auto expected = NDArrayFactory::create('c', {2,4,4}, {1., 2., 3., 4., 5., 6., 7., 8., 9.,10.,11.,12., 5., 6., 7., 8.,13.,14.,15.,16.,17.,18.,19.,20.,21.,22.,23.,24.,17.,18.,19.,20.}); input.linspace(1.f); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests12) { auto input = NDArrayFactory::create('c', {2,3,4,5}); auto paddings = NDArrayFactory::create('c', {4,2}, {0,0, 0,1, 0,1, 0,0}); auto expected = NDArrayFactory::create('c', {2,4,5,5}, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 56., 57., 58., 59., 60., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., 77., 78., 79., 80., 76., 77., 78., 79., 80., 81., 82., 83., 84., 85., 86., 87., 88., 89., 90., 91., 92., 93., 94., 95., 96., 97., 98., 99.,100., 96., 97., 98., 99.,100., 101.,102.,103.,104.,105.,106.,107.,108.,109.,110.,111.,112.,113.,114.,115.,116.,117.,118.,119.,120.,116.,117.,118.,119.,120., 101.,102.,103.,104.,105.,106.,107.,108.,109.,110.,111.,112.,113.,114.,115.,116.,117.,118.,119.,120.,116.,117.,118.,119.,120.}); input.linspace(1.f); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests13) { auto input = NDArrayFactory::create('c', {5}); auto paddings = NDArrayFactory::create('c', {1,2}, {2,3}); auto expected = NDArrayFactory::create('c', {10}, {3., 2., 1., 2., 3., 4., 5., 4., 3., 2.}); input.linspace(1.f); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests14) { auto input = NDArrayFactory::create('c', {1,5}); auto paddings = NDArrayFactory::create('c', {2,2}, {0,0,2,3}); auto expected = NDArrayFactory::create('c', {1,10}, {2., 1., 1., 2., 3., 4., 5., 5., 4., 3.}); input.linspace(1.f); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests15) { auto input = NDArrayFactory::create('c', {1,5}); auto paddings = NDArrayFactory::create('c', {2,2}, {1,1,0,0}); auto expected = NDArrayFactory::create('c', {3,5}, {1., 2., 3., 4., 5., 1., 2., 3., 4., 5., 1., 2., 3., 4., 5.}); input.linspace(1.f); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests16) { auto input = NDArrayFactory::create('c', {5,1}); auto paddings = NDArrayFactory::create('c', {2,2}, {2,3,0,0}); auto expected = NDArrayFactory::create('c', {10,1}, {3., 2., 1., 2., 3., 4., 5., 4., 3., 2.}); input.linspace(1.f); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests17) { auto input = NDArrayFactory::create('c', {5,1}); auto paddings = NDArrayFactory::create('c', {2,2}, {0,0,1,0}); auto expected = NDArrayFactory::create('c', {5,2}, {1.,1., 2.,2., 3.,3., 4.,4., 5.,5.}); input.linspace(1.f); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests18) { auto input = NDArrayFactory::create('c', {5}); auto paddings = NDArrayFactory::create('c', {1,2}, {0,0}); auto expected = NDArrayFactory::create('c', {5}, {1.,2.,3.,4.,5.}); input.linspace(1.f); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests19) { auto input = NDArrayFactory::create('c', {5,1}); auto paddings = NDArrayFactory::create('c', {2,2}, {0,0,0,0}); auto expected = NDArrayFactory::create('c', {5,1}, {1., 2., 3., 4., 5.}); input.linspace(1.f); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests20) { auto input = NDArrayFactory::create('c', {1,5}); auto paddings = NDArrayFactory::create('c', {2,2}, {0,0,0,0}); auto expected = NDArrayFactory::create('c', {1,5}, {1., 2., 3., 4., 5.}); input.linspace(1.f); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests21) { auto input = NDArrayFactory::create('c', {1,3,1,5}); auto paddings = NDArrayFactory::create('c', {4,2}, {0,0, 0,1, 0,1, 0,0}); auto expected = NDArrayFactory::create('c', {1,4,2,5}, {1., 2., 3., 4., 5., 1., 2., 3., 4., 5., 6., 7., 8., 9.,10., 6., 7., 8., 9.,10., 11.,12.,13.,14.,15.,11.,12.,13.,14.,15.,11.,12.,13.,14.,15.,11.,12.,13.,14.,15.}); input.linspace(1.f); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests22) { auto input = NDArrayFactory::create('c', {1,1}); auto paddings = NDArrayFactory::create('c', {2,2}, {0,0, 0,0}); auto expected = NDArrayFactory::create('c', {1,1}, {1.}); input.linspace(1.f); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests23) { auto input = NDArrayFactory::create('c', {1,1}); auto paddings = NDArrayFactory::create('c', {2,2}, {0,0, 1,0}); auto expected = NDArrayFactory::create('c', {1,2}, {0.,1.}); input.linspace(1.f); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); // result->printShapeInfo("r"); // expected.printShapeInfo("e"); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests24) { auto input = NDArrayFactory::create('c', {1}); auto paddings = NDArrayFactory::create('c', {1,2}, {0,0}); auto expected = NDArrayFactory::create('c', {1}, {1.}); input.linspace(1.f); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests25) { auto input = NDArrayFactory::create('c', {1}); auto paddings = NDArrayFactory::create('c', {1,2}, {1,1}); auto expected = NDArrayFactory::create('c', {3}, {1.,1.,1}); input.linspace(1.f); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests26) { auto input = NDArrayFactory::create('c', {1}); auto paddings = NDArrayFactory::create('c', {1,2}, {3,2}); auto expected = NDArrayFactory::create('c', {6}, {0., 0., 0., 1., 0., 0.}); input.linspace(1.f); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests27) { NDArray input('c', {2,3}, nd4j::DataType::FLOAT32); NDArray paddings('c', {2,2}, {0,0,0,1}, nd4j::DataType::INT32); NDArray exp('c', {2,4}, {1,1,1,0,1,1,1,0}, nd4j::DataType::FLOAT32); NDArray z('c', {2,4}, nd4j::DataType::FLOAT32); input = 1.; nd4j::ops::pad op; Nd4jStatus status = op.execute({&input, &paddings}, {&z}, {0}, {0}, {}); // constant // z.printIndexedBuffer(); ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_TRUE(exp.isSameShapeStrict(&z)); ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests28) { NDArray input('c', {1,111,111,32}, nd4j::DataType::FLOAT32); NDArray paddings('c', {4,2}, {0,0,0,1,0,1,0,0}, nd4j::DataType::INT32); NDArray z('c', {1,112,112,32}, nd4j::DataType::FLOAT32); input = 1.; nd4j::ops::pad op; Nd4jStatus status = op.execute({&input, &paddings}, {&z}, {0}, {0}, {}); // constant // z.printIndexedBuffer(); NDArray sum = z.reduceNumber(nd4j::reduce::Sum); ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_EQ(sum.e(0), 111*111*32); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests29) { auto in = NDArrayFactory::create({1., 1., 1., 1., 1.}); // auto pad = NDArrayFactory::create('c', {1, 2}, {1., 1.});// = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); // auto value(10.0); auto exp = NDArrayFactory::create({10., 1., 1., 1., 1., 1., 10.}); nd4j::ops::pad op; auto res = op.execute({&in, &pad}, {10.0}, {0}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_TRUE(exp.equalsTo(res->at(0))); delete res; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests30) { auto in = NDArrayFactory::create({1., 11., 111., 11., 1.}); auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); auto exp = NDArrayFactory::create({1., 1., 11., 111., 11., 1., 1.}); nd4j::ops::pad op; auto res = op.execute({&in, &pad}, {10.0}, {2}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_TRUE(exp.equalsTo(res->at(0))); delete res; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests31) { auto in = NDArrayFactory::create({1., 11., 111., 1111., 11111.}); // auto pad = NDArrayFactory::create('c', {1, 2}, {1., 1.});// = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); // auto value(10.0); auto exp = NDArrayFactory::create({11., 1., 11., 111., 1111., 11111., 1111.}); nd4j::ops::pad op; auto res = op.execute({&in, &pad}, {10.0}, {1}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_TRUE(exp.equalsTo(res->at(0))); delete res; } /////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests32) { auto in = NDArrayFactory::create('c', {3,3}, {1., 2., 3., 4., 5.,6,7,8,9}); auto pad = NDArrayFactory::create('c', {2,2}, {1, 2, 2, 3}); auto exp = NDArrayFactory::create('c', {6,8}, {2, 1, 1, 2, 3, 3, 2, 1, 2, 1, 1, 2, 3, 3, 2, 1, 5, 4, 4, 5, 6, 6, 5, 4, 8, 7, 7, 8, 9, 9, 8, 7, 8, 7, 7, 8, 9, 9, 8, 7, 5, 4, 4, 5, 6, 6, 5, 4}); nd4j::ops::pad op; auto res = op.execute({&in, &pad}, {10.0}, {2}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_TRUE(exp.equalsTo(res->at(0))); delete res; } /////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests33) { auto in = NDArrayFactory::create('c', {2,3,4}, {1, 2, 3, 4,5, 6, 7, 8,9,10,11,12,13, 14, 15, 16,17, 18, 19, 20,21, 22, 23, 24}); auto pad = NDArrayFactory::create('c', {3,2}, {1, 2, 2, 3, 3,3}); auto exp = NDArrayFactory::create('c', {5,8,10}, { 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2., 3,2,1,1,2,3,4,4,3,2., 7,6,5,5,6,7,8,8,7,6., 11,10,9,9,10,11,12,12,11,10., 11,10,9,9,10,11,12,12,11,10., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2., 3,2,1,1,2,3,4,4,3,2., 7,6,5,5,6,7,8,8,7,6., 11,10,9,9,10,11,12,12,11,10., 11,10,9,9,10,11,12,12,11,10.,7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2., 19,18,17,17,18,19,20,20,19,18., 15,14,13,13,14,15,16,16,15,14., 15,14,13,13,14,15,16,16,15,14., 19,18,17,17,18,19,20,20,19,18., 23,22,21,21,22,23,24,24,23,22., 23,22,21,21,22,23,24,24,23,22., 19,18,17,17,18,19,20,20,19,18., 15,14,13,13,14,15,16,16,15,14., 19,18,17,17,18,19,20,20,19,18., 15,14,13,13,14,15,16,16,15,14., 15,14,13,13,14,15,16,16,15,14., 19,18,17,17,18,19,20,20,19,18., 23,22,21,21,22,23,24,24,23,22., 23,22,21,21,22,23,24,24,23,22., 19,18,17,17,18,19,20,20,19,18., 15,14,13,13,14,15,16,16,15,14., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2., 3,2,1,1,2,3,4,4,3,2., 7,6,5,5,6,7,8,8,7,6., 11,10,9,9,10,11,12,12,11,10., 11,10,9,9,10,11,12,12,11,10., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2.}); nd4j::ops::pad op; auto res = op.execute({&in, &pad}, {10.0}, {2}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_TRUE(exp.equalsTo(res->at(0))); delete res; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests34) { NDArray input('c', {5}, {0.778786, 0.801198, 0.724375, 0.230894, 0.727141}, nd4j::DataType::FLOAT32); NDArray paddings('c', {1,2}, {1,1}, nd4j::DataType::INT32); NDArray expected('c', {7}, {10., 0.778786, 0.801198, 0.724375, 0.230894, 0.727141, 10.}, nd4j::DataType::FLOAT32); NDArray z('c', {7}, nd4j::DataType::FLOAT32); nd4j::ops::pad op; Nd4jStatus status = op.execute({&input, &paddings}, {&z}, {10}, {0}, {}); // constant ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_TRUE(expected.isSameShapeStrict(&z)); ASSERT_TRUE(expected.equalsTo(z)); } //////////////////////////////////////////////////////////////////// // CONSTANT mode 2D TEST_F(DeclarableOpsTests12, Pad_1) { float inBuff[] = {1,2,3,4,5,6}; int padBuff[] = {1,1,2,2}; float expBuff[] = {0,0,0,0,0,0,0, 0,0,1,2,3,0,0, 0,0,4,5,6,0,0, 0,0,0,0,0,0,0}; auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// // REFLECT mode 2D TEST_F(DeclarableOpsTests12, Pad_2) { float inBuff[] = {1,2,3,4,5,6}; int padBuff[] = {1,1,2,2}; float expBuff[] = {6,5,4,5,6,5,4, 3,2,1,2,3,2,1, 6,5,4,5,6,5,4, 3,2,1,2,3,2,1}; auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// // SYMMETRIC mode 2D TEST_F(DeclarableOpsTests12, Pad_3) { float inBuff[] = {1,2,3,4,5,6}; int padBuff[] = {1,1,2,2}; float expBuff[] = {2,1,1,2,3,3,2, 2,1,1,2,3,3,2, 5,4,4,5,6,6,5, 5,4,4,5,6,6,5}; auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// // CONSTANT mode 3D TEST_F(DeclarableOpsTests12, Pad_4) { float inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; int padBuff[] = {1,1,2,2,2,2}; float expBuff[] = {0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 1, 2, 3,0,0,0,0, 4, 5, 6,0,0,0,0, 7, 8, 9,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0,10,11,12,0,0,0,0,13,14,15,0,0,0,0,16,17,18,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0}; auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// // REFLECT mode 3D TEST_F(DeclarableOpsTests12, Pad_5) { float inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; int padBuff[] = {1,1,2,2,2,2}; float expBuff[] = {18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}; auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// // SYMMETRIC mode 3D TEST_F(DeclarableOpsTests12, Pad_6) { float inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; int padBuff[] = {1,1,2,2,2,2}; float expBuff[] = {5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14}; auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// // CONSTANT mode 4D TEST_F(DeclarableOpsTests12, Pad_7) { float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; float expBuff[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 10, 0, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto *result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } //////////////////////////////////////////////////////////////////// // REFLECT mode 4D TEST_F(DeclarableOpsTests12, Pad_8) { float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; float expBuff[] = {16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1}; auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto *result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } ////////////////////////////////////////////////////////////////// // SYMMETRIC mode 4D TEST_F(DeclarableOpsTests12, Pad_9) { float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; float expBuff[] = {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16}; auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto *result = results->at(0); // result->printIndexedBuffer(); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; } TEST_F(DeclarableOpsTests12, Test_Expose_1) { auto input0 = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 6, 5, 4}); auto input1 = NDArrayFactory::create('c', {2, 3}, {3, 2, 1, 4, 5, 6}); nd4j::ops::expose op; auto result = op.execute({&input0, &input1}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z0 = result->at(0); auto z1 = result->at(1); ASSERT_TRUE(input0.equalsTo(z0)); ASSERT_TRUE(input1.equalsTo(z1)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, Pad_SGO_Test_1) { auto in = NDArrayFactory::create({1., 1., 1., 1., 1.}); // auto pad = NDArrayFactory::create('c', {1, 2}, {1., 1.});// = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); // auto value(10.0); auto exp = NDArrayFactory::create({10., 1., 1., 1., 1., 1., 10.}); nd4j::ops::pad op; auto res = op.execute({&in, &pad}, {10.0}, {0}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); // res->at(0)->printIndexedBuffer("PAD_SGO"); // exp.printIndexedBuffer("PAD_EXP"); ASSERT_TRUE(exp.equalsTo(res->at(0))); delete res; }