cavis/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp
shugeo dc0036f2c6
Shugeo image resize bicubic (#56)
* Added implementation files for image_resize and resize_bicubic ops.

* Image resize and image.resize_bicubic ops implementation. Initial revision.

* Finished the infrastructure for the image.resize_bilinear op and the image_resize op implementation.

* Refactored resize methods.

* Added processing for the Mitchell-cubic algorithm.

* Added check for input/output sizes.

* Added int and float types for crop_and_resize op.

* Refactored crop_and_resize output type check.

* Added helper for bicubic interpolation as TF v.1 does.

* Added TF v.1 bicubic helper for cuda platform.

* Added cached class for bicubic algorithm.

* Refactored cuda implementation for crop_and_resize helper to use proper output type.

* Added facilities for bicubic interpolation.

* Ported bicubic interpolation from TF.

* Added tests for resize_bilinear testing.

* Working implementation of bicubic interpolation and tests.

* Refactored routines with image_resize bicubic op helper.

* Refactored code with coding standards.

* Refactored cpu helpers for resize_bicubic op.

* Refactored bicubic helpers.

* Added bicubic resize facilities.

* Implementing cuda kernels for bicubic interpolation. Implementation step.

* Cuda implementation of resize_bicubic op helper.

* Refactor image.resize_bicubic op helpers.

* Refactored helpers for resize_bicubic. Added error checking with cuda implementation.

* Refactored cuda implementation of resize_bicubic op helper. The first working revision.

* Cuda arch implementation for resize_bicubic op helper. Full working single-threaded revision.

* Intermediate bicubic interpolation helper for cuda.

* Refactored cpu helper for resize_bicubic.

* Multithreaded cuda implementation for resize_bicubic.

* Fixed merge issues.

* Refactored nlp helpers.

* Replicated resize_bicubic for 3D also.

* Removed leftover comments containing unused code.

* Removed leftover comments containing unused code.

* Removed unused template definitions.

* Removed leftover debug code.

* Removed leftover comments.

* Fixed multithreading with helpers.

* Fixed test suites for float and double in floating-point input lists.

* Fixed usage of reshape with 3D/4D on resizes.

* Final fixes.

* Fixed resize_neighbor op problem.
2019-11-20 21:11:04 +02:00

3121 lines
130 KiB
C++

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver on 8/4/2018.
//
#include "testlayers.h"
#include <memory>
#include <ops/declarable/CustomOperations.h>
#include <NDArray.h>
#include <ops/ops.h>
#include <GradCheck.h>
using namespace nd4j;
/// gtest fixture shared by every test in this file.
/// Its only behaviour is to write a single newline to stdout and flush it
/// before each test body runs (presumably to visually separate per-test output).
class DeclarableOpsTests11 : public testing::Test {
public:
    DeclarableOpsTests11() {
        fputc('\n', stdout);   // same output as printf("\n")
        fflush(stdout);        // push it out immediately, before the test prints anything
    }
};
TEST_F(DeclarableOpsTests11, test_listdiff_1) {
    // Smoke test for listdiff on x = {0,1,2,3}, y = {3,1}: only the op status is checked.
    auto x = NDArrayFactory::create<int>('c', {4}, {0, 1, 2, 3});
    auto y = NDArrayFactory::create<int>('c', {2}, {3, 1});

    nd4j::ops::listdiff op;
    // Own the ResultSet so it is released even if the ASSERT below fails and
    // returns early from the test body (a raw pointer + delete would leak then).
    std::unique_ptr<nd4j::ResultSet> result(op.execute({&x, &y}, {}, {}));
    ASSERT_EQ(Status::OK(), result->status());
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test1) {
    // Full-shape weights {2,3,4}, reduction iArg 0 (presumably "none" — confirm against
    // the op declaration); all three gradients keep the {2,3,4} shape.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692,
                                   -24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911});
    NDArray dLdwExp('c', {2,3,4}, {3.21887, 4.96807, 6.10512, 6.80726, 7.15461, 7.19051, 6.93973, 6.41584, 5.62456, 4.56548, 3.2326 , 1.61444,
                                   -0.30659, -2.55529, -5.16569, -8.18417,-11.67468,-15.72734,-20.47379,-26.11644,-32.9902 ,-41.71318,-53.64824,-73.05434});
    NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002,
                                   -0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903});

    predictions.linspace(0.04, 0.04);   // 0.04, 0.08, ..., 0.96
    labels.linspace(1);                 // 1, 2, ..., 24
    weights.assign(0.5);

    nd4j::ops::log_loss_grad op;
    // Inputs are {predictions, weights, labels}; tArg 1e-7 is presumably the epsilon
    // that guards log(0). unique_ptr frees the ResultSet even if an ASSERT_* returns early.
    std::unique_ptr<nd4j::ResultSet> results(op.execute({&predictions, &weights, &labels}, {1e-7}, {0}, {}));
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
    ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test2) {
    // Weights of shape {2,1,4} against {2,3,4} predictions/labels, reduction iArg 0
    // (presumably "none"); only dLdw is checked, and it must keep the {2,1,4} weights shape.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,1,4}, nd4j::DataType::DOUBLE);

    NDArray dLdwExp('c', {2,1,4}, {15.99805, 16.72406, 16.27746, 14.83754,-44.97147,-59.99582,-79.28771,-107.35497});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    nd4j::ops::log_loss_grad op;
    // unique_ptr releases the ResultSet even if an ASSERT_* below returns early.
    std::unique_ptr<nd4j::ResultSet> results(op.execute({&predictions, &weights, &labels}, {1e-7}, {0}));
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdw = results->at(1);

    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test3) {
    // Scalar weights, reduction iArg 1 (presumably "weighted sum"); dLdw collapses to a
    // scalar while dLdp/dLdl match the unreduced case of log_loss_grad_test1.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692,
                                   -24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911});
    NDArray dLdwExp('c', {}, {-227.77286});
    NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002,
                                   -0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    nd4j::ops::log_loss_grad op;
    // unique_ptr releases the ResultSet even if an ASSERT_* below returns early.
    std::unique_ptr<nd4j::ResultSet> results(op.execute({&predictions, &weights, &labels}, {1e-7}, {1}));
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
    ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test4) {
    // Weights of shape {1,3,1}, reduction iArg 1 (presumably "weighted sum");
    // only dLdw is checked and it must keep the {1,3,1} weights shape.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);

    NDArray dLdwExp('c', {1,3,1}, {4.8876 , -46.29156, -186.36887});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    nd4j::ops::log_loss_grad op;
    // unique_ptr releases the ResultSet even if an ASSERT_* below returns early.
    std::unique_ptr<nd4j::ResultSet> results(op.execute({&predictions, &weights, &labels}, {1e-7}, {1}));
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdw = results->at(1);
    // dLdw->printIndexedBuffer();
    // dLdw->printShapeInfo();

    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test5) {
    // Full-shape weights, reduction iArg 2 (presumably "weighted mean"): the expected dLdp is
    // log_loss_grad_test1's dLdp divided by the weight sum (24 * 0.5 = 12).
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,3,4}, {-1.04166,-1.08696, -1.13636, -1.19048,-1.25 ,-1.31579, -1.38889, -1.47059,-1.5625 ,-1.66667, -1.78571, -1.92308,
                                   -2.08333,-2.27273, -2.5 , -2.77778,-3.125 ,-3.57143, -4.16667, -5. ,-6.25 ,-8.33333,-12.49999,-24.99993});
    NDArray dLdwExp('c', {2,3,4}, {1.05912, 1.20488, 1.29964, 1.35815, 1.3871 , 1.39009, 1.36919, 1.32553, 1.25959, 1.17133, 1.06026, 0.92541,
                                   0.76533, 0.57794, 0.3604 , 0.10886,-0.18201,-0.51973,-0.91527,-1.38549,-1.95831,-2.68522,-3.67981,-5.29698});
    NDArray dLdlExp('c', {2,3,4}, {0.13242, 0.10176, 0.08302, 0.06909, 0.05776, 0.04803, 0.03935, 0.03141, 0.02397, 0.01689, 0.01005, 0.00334,
                                   -0.00334,-0.01005,-0.01689,-0.02397,-0.03141,-0.03935,-0.04803,-0.05776,-0.06909,-0.08302,-0.10176,-0.13242});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    nd4j::ops::log_loss_grad op;
    // unique_ptr releases the ResultSet even if an ASSERT_* below returns early.
    std::unique_ptr<nd4j::ResultSet> results(op.execute({&predictions, &weights, &labels}, {1e-7}, {2}));
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
    ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test6) {
    // Weights of shape {1,3,1}, reduction iArg 2 (presumably "weighted mean");
    // only dLdw is checked and it must keep the {1,3,1} weights shape.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);

    NDArray dLdwExp('c', {1,3,1}, {6.73432, 2.46939,-9.20372});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    nd4j::ops::log_loss_grad op;
    // unique_ptr releases the ResultSet even if an ASSERT_* below returns early.
    std::unique_ptr<nd4j::ResultSet> results(op.execute({&predictions, &weights, &labels}, {1e-7}, {2}));
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdw = results->at(1);

    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test7) {
    // Scalar weight with reduction iArg 2 (presumably "weighted mean"). A single scalar
    // weight cancels out of a weighted mean, which is why dLdw is expected to be exactly 0.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);

    NDArray dLdwExp('c', {}, {0.});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    nd4j::ops::log_loss_grad op;
    // unique_ptr releases the ResultSet even if an ASSERT_* below returns early.
    std::unique_ptr<nd4j::ResultSet> results(op.execute({&predictions, &weights, &labels}, {1e-7}, {2}));
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdw = results->at(1);

    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test8) {
    // Full-shape weights with the first four zeroed, reduction iArg 2 (presumably
    // "weighted mean"); zero-weight positions get zero dLdp/dLdl.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. ,-1.5 ,-1.57895, -1.66667, -1.76471,-1.875 ,-2. , -2.14286, -2.30769,
                                   -2.5 ,-2.72727, -3. , -3.33333,-3.75 ,-4.28571, -5. , -6. ,-7.49999,-9.99999,-14.99999,-29.99991});
    NDArray dLdwExp('c', {2,3,4}, {1.56625, 1.74117, 1.85487, 1.92509, 1.95982, 1.96341, 1.93833, 1.88594, 1.80682, 1.70091, 1.56762, 1.4058 ,
                                   1.2137 , 0.98883, 0.72779, 0.42594, 0.07689,-0.32837,-0.80302,-1.36728,-2.05466,-2.92696,-4.12046,-6.06107});
    NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0.06931, 0.05763, 0.04722, 0.03769, 0.02877, 0.02027, 0.01206, 0.004,
                                   -0.004 ,-0.01206,-0.02027,-0.02877,-0.03769,-0.04722,-0.05763,-0.06931,-0.08291,-0.09962,-0.12212,-0.1589});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // Zero the first four (flat-indexed) weights.
    weights.p(0, 0.);
    weights.p(1, 0.);
    weights.p(2, 0.);
    weights.p(3, 0.);

    nd4j::ops::log_loss_grad op;
    // unique_ptr releases the ResultSet even if an ASSERT_* below returns early.
    std::unique_ptr<nd4j::ResultSet> results(op.execute({&predictions, &weights, &labels}, {1e-7}, {2}));
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
    ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test9) {
    // Full-shape weights, reduction iArg 3 (presumably "sum by number of non-zero weights"):
    // the expected dLdp is log_loss_grad_test1's dLdp divided by 24 non-zero weights.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,3,4}, {-0.52083,-0.54348,-0.56818, -0.59524,-0.625 ,-0.65789,-0.69444, -0.73529,-0.78125,-0.83333,-0.89286, -0.96154,
                                   -1.04167,-1.13636,-1.25 , -1.38889,-1.5625 ,-1.78571,-2.08333, -2.5 ,-3.125 ,-4.16666,-6.24999,-12.49996});
    NDArray dLdwExp('c', {2,3,4}, {0.13412, 0.207 , 0.25438, 0.28364, 0.29811, 0.2996 , 0.28916, 0.26733, 0.23436, 0.19023, 0.13469, 0.06727,
                                   -0.01277,-0.10647,-0.21524,-0.34101,-0.48645,-0.65531,-0.85307,-1.08819,-1.37459,-1.73805,-2.23534,-3.04393});
    NDArray dLdlExp('c', {2,3,4}, {0.06621, 0.05088, 0.04151, 0.03455, 0.02888, 0.02401, 0.01968, 0.0157 , 0.01199, 0.00845, 0.00502, 0.00167,
                                   -0.00167,-0.00502,-0.00845,-0.01199,-0.0157 ,-0.01968,-0.02401,-0.02888,-0.03455,-0.04151,-0.05088,-0.06621});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    nd4j::ops::log_loss_grad op;
    // unique_ptr releases the ResultSet even if an ASSERT_* below returns early.
    std::unique_ptr<nd4j::ResultSet> results(op.execute({&predictions, &weights, &labels}, {1e-7}, {3}));
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
    ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test10) {
    // Weights of shape {1,1}, reduction iArg 3 (presumably "sum by non-zero weights");
    // only dLdw is checked and it must keep the {1,1} weights shape.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);

    NDArray dLdwExp('c', {1,1}, {-9.49054});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    nd4j::ops::log_loss_grad op;
    // unique_ptr releases the ResultSet even if an ASSERT_* below returns early.
    std::unique_ptr<nd4j::ResultSet> results(op.execute({&predictions, &weights, &labels}, {1e-7}, {3}));
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdw = results->at(1);

    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test11) {
    // Weights of shape {1,3,1}, reduction iArg 3 (presumably "sum by non-zero weights");
    // only dLdw is checked and it must keep the {1,3,1} weights shape.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);

    NDArray dLdwExp('c', {1,3,1}, {0.20365,-1.92882,-7.76537});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    nd4j::ops::log_loss_grad op;
    // unique_ptr releases the ResultSet even if an ASSERT_* below returns early.
    std::unique_ptr<nd4j::ResultSet> results(op.execute({&predictions, &weights, &labels}, {1e-7}, {3}));
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdw = results->at(1);

    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test12) {
    // Full-shape weights with the first four zeroed, reduction iArg 3 (presumably
    // "sum by non-zero weights"): 20 non-zero weights remain, and zero-weight
    // positions get zero dLdp/dLdl.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,3,4}, { 0. , 0. , 0. , 0. ,-0.75 ,-0.789473,-0.833333, -0.882353,-0.9375 ,-1. ,-1.071428, -1.153846,
                                   -1.25 ,-1.363636,-1.5 , -1.666666,-1.875 ,-2.142857,-2.499999, -2.999999,-3.749997,-4.999997,-7.499993,-14.999956});
    NDArray dLdwExp('c', {2,3,4}, {0.16094, 0.2484 , 0.30526, 0.34036, 0.35773, 0.35953, 0.34699, 0.32079, 0.28123, 0.22827, 0.16163, 0.08072,
                                   -0.01533,-0.12776,-0.25828,-0.40921,-0.58373,-0.78637,-1.02369,-1.30582,-1.64951,-2.08566,-2.68241,-3.65272});
    NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0.03466, 0.02882, 0.02361, 0.01884, 0.01438, 0.01014, 0.00603, 0.002 ,
                                   -0.002 ,-0.00603,-0.01014,-0.01438,-0.01884,-0.02361,-0.02882,-0.03466,-0.04146,-0.04981,-0.06106,-0.07945});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // Zero the first four (flat-indexed) weights.
    weights.t<double>(0) = 0.;
    weights.t<double>(1) = 0.;
    weights.t<double>(2) = 0.;
    weights.t<double>(3) = 0.;

    nd4j::ops::log_loss_grad op;
    // unique_ptr releases the ResultSet even if an ASSERT_* below returns early.
    std::unique_ptr<nd4j::ResultSet> results(op.execute({&predictions, &weights, &labels}, {1e-7}, {3}));
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
    ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test13) {
    // Weights of shape {2,3,1}, reduction iArg 3 (presumably "sum by non-zero weights").
    // Zeroing flat weights 0..2 zeroes the whole first batch, so the first 12
    // entries of dLdp/dLdl are expected to be 0.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,1}, nd4j::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
                                   -2.08333,-2.27273, -2.5 , -2.77778,-3.125 ,-3.57143, -4.16667, -5. ,-6.25 ,-8.33333,-12.49999,-24.99993});
    NDArray dLdwExp('c', {2,3,1}, {1.75828, 2.30839, 1.25309, -1.35098, -6.16602,-16.78383});
    NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
                                   -0.00334,-0.01005,-0.01689,-0.02397,-0.03141,-0.03935,-0.04803,-0.05776,-0.06909,-0.08302,-0.10176,-0.13242});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    weights.t<double>(0) = 0.;
    weights.t<double>(1) = 0.;
    weights.t<double>(2) = 0.;

    nd4j::ops::log_loss_grad op;
    // unique_ptr releases the ResultSet even if an ASSERT_* below returns early.
    std::unique_ptr<nd4j::ResultSet> results(op.execute({&predictions, &weights, &labels}, {1e-7}, {3}));
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
    ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
}
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test1) {
    // Bicubic upscaling of a single-channel 7x7 image, laid out {batch, height, width, channels},
    // to 30x30.
    NDArray input = NDArrayFactory::create<double>('c', {1, 7, 7, 1}, {
        1, 2.1, 3.15, 4.2, 5.15, 6.1, 7,
        8, 9.1, 10., 11, 12.9, 13.1, 14,
        15, 16., 17., 18, 19, 20., 21,
        22, 23., 24., 25, 26, 27, 28,
        30, 31, 32, 33, 34., 35, 36,
        37, 38, 39, 40, 41., 42, 43,
        44, 45, 46, 47, 48., 49, 50
    });

    NDArray expected = NDArrayFactory::create<double>('c', {1, 30, 30, 1}, {
        1. ,1.1976162 ,1.4174359 ,1.6775769 ,1.9961575 ,2.3283265 ,
        2.550918 ,2.7360606 ,2.9655411 ,3.2929654 ,3.5441515 ,3.7380352 ,
        3.948995 ,4.248106 ,4.5073795 ,4.6843743 ,4.8572845 ,5.104302 ,
        5.3869915 ,5.581401 ,5.7539616 ,5.974285 ,6.272836 ,6.5204263 ,
        6.718899 ,6.8871036 ,7.039068 ,7.099216 ,7.0784245 ,7.0281887 ,
        2.247592 ,2.446947 ,2.6694887 ,2.9312382 ,3.248216 ,3.5745337 ,
        3.78931 ,3.9656973 ,4.186417 ,4.5046535 ,4.740569 ,4.9217057 ,
        5.133866 ,5.459533 ,5.7744613 ,6.0197873 ,6.254011 ,6.535633 ,
        6.8097296 ,6.9607787 ,7.0749416 ,7.241601 ,7.5094895 ,7.7499495 ,
        7.954571 ,8.131972 ,8.286526 ,8.346463 ,8.325745 ,8.275683 ,
        3.6286845 ,3.830573 ,4.0569587 ,4.3211575 ,4.6364856 ,4.9556503 ,
        5.160583 ,5.3258467 ,5.535462 ,5.84216 ,6.058749 ,6.223753 ,
        6.437597 ,6.797369 ,7.1836042 ,7.5164022 ,7.8290343 ,8.154773 ,
        8.417635 ,8.512958 ,8.5521 ,8.649708 ,8.87788 ,9.108794 ,
        9.320926 ,9.509781 ,9.667375 ,9.72694 ,9.706349 ,9.656599 ,
        5.276778 ,5.480438 ,5.709702 ,5.9754477 ,6.288551 ,6.6005697 ,
        6.796207 ,6.9511423 ,7.1503997 ,7.4461427 ,7.644651 ,7.794562 ,
        8.009684 ,8.400473 ,8.851847 ,9.26469 ,9.649218, 10.015648 ,
        10.268647 ,10.313368 ,10.2843275 ,10.319379 ,10.512033 ,10.734956 ,
        10.954604 ,11.154507 ,11.315369 ,11.374779 ,11.354242 ,11.304622 ,
        7.325373 ,7.5284843 ,7.757575 ,8.022221 ,8.331997 ,8.638187 ,
        8.827649 ,8.976217 ,9.168955 ,9.45726 ,9.6442375 ,9.784517 ,
        9.999621, 10.407702 ,10.896234, 11.355122, 11.781423, 12.172186 ,
        12.420712 ,12.4374485 ,12.370511 ,12.371386 ,12.545973 ,12.766424 ,
        12.992249 ,13.20012 ,13.364252 ,13.424109 ,13.40342 ,13.353425 ,
        9.493208 ,9.692467 ,9.9169445, 10.176801, 10.482199, 10.78547 ,
        10.974367 ,11.123442 ,11.31637 ,11.603645 ,11.790616 ,11.930889 ,
        12.144082 ,12.546447 ,13.024898 ,13.4723 ,13.889232 ,14.276275 ,
        14.528972 ,14.555555 ,14.50145 ,14.515459 ,14.700572 ,14.927055 ,
        15.156046 ,15.366046 ,15.532901 ,15.594008 ,15.5728855 ,15.521847 ,
        10.970133 ,11.163599 ,11.380694 ,11.633735 ,11.935032 ,12.238887 ,
        12.43254 ,12.588294 ,12.787534 ,13.079956 ,13.27752 ,13.426631 ,
        13.636713 ,14.013844 ,14.441672 ,14.827978 ,15.191209 ,15.549808 ,
        15.81343 ,15.881828 ,15.883522 ,15.950411 ,16.16933 ,16.40794 ,
        16.636436 ,16.842583 ,17.010887 ,17.07363 ,17.05194 ,16.999537 ,
        12.219155 ,12.406129 ,12.614796 ,12.860335 ,13.157928 ,13.464224 ,
        13.665207 ,13.830567 ,14.039036 ,14.339629 ,14.552863 ,14.715049 ,
        14.921564 ,15.264454 ,15.622843 ,15.924977 ,16.213829 ,16.532364 ,
        16.8099 ,16.934835 ,17.012146 ,17.150164 ,17.413412 ,17.666712 ,
        17.892765 ,18.09207 ,18.261044 ,18.325531 ,18.303238 ,18.249378 ,
        13.7663965 ,13.947391 ,14.148263 ,14.386917 ,14.681246 ,14.990087 ,
        15.198166 ,15.372728 ,15.590062 ,15.898583 ,16.126892 ,16.301655 ,
        16.50487 ,16.815214 ,17.107498 ,17.329458 ,17.547403 ,17.827654 ,
        18.118288 ,18.296928 ,18.4461 ,18.651634 ,18.956806 ,19.22382 ,
        19.447308 ,19.639887 ,19.809319 ,19.875397 ,19.852556 ,19.797365 ,
        15.9419365 ,16.118704 ,16.314133 ,16.547867 ,16.839561 ,17.14954 ,
        17.361883 ,17.542162 ,17.764957 ,18.078188 ,18.315733 ,18.498205 ,
        18.699116 ,18.988684 ,19.238989 ,19.410137 ,19.583265 ,19.839512 ,
        20.13878 ,20.35177 ,20.546844 ,20.795671 ,21.128067 ,21.404358 ,
        21.626736 ,21.8155 ,21.98561 ,22.052843 ,22.029604 ,21.973448 ,
        17.53522 ,17.71077 ,17.904636 ,18.13695 ,18.42784 ,18.738056 ,
        18.951529 ,19.133352 ,19.357613 ,19.672083 ,19.912102 ,20.096638 ,
        20.296894 ,20.580765 ,20.819603 ,20.976887 ,21.137802 ,21.387535 ,
        21.689209 ,21.911621 ,22.119276 ,22.37999 ,22.71991 ,22.998823 ,
        23.22097 ,23.40876 ,23.57911 ,23.646685 ,23.623325 ,23.566887 ,
        18.746353 ,18.922657 ,19.117487 ,19.350685 ,19.64207 ,19.952137 ,
        20.164913 ,20.345781 ,20.569134 ,20.88284 ,21.12133 ,21.30459 ,
        21.505253 ,21.792645 ,22.038572 ,22.204426 ,22.37289 ,22.626648 ,
        22.926834 ,23.143423 ,23.343302 ,23.596668 ,23.931936 ,24.209232 ,
        24.431519 ,24.619913 ,24.79011 ,24.857473 ,24.83419 ,24.777927 ,
        20.16656 ,20.344206 ,20.540766 ,20.775532 ,21.067804 ,21.377607 ,
        21.589132 ,21.768297 ,21.99003 ,22.302366 ,22.538124 ,22.719105 ,
        22.920494 ,23.214176 ,23.472767 ,23.653934 ,23.83589 ,24.096842 ,
        24.394371 ,24.600555 ,24.786541 ,25.026773 ,25.353731 ,25.62813 ,
        25.850672 ,26.04014 ,26.210072 ,26.277063 ,26.253906 ,26.197956 ,
        22.363024 ,22.54125 ,22.738552 ,22.973991 ,23.266647 ,23.57634 ,
        23.787327 ,23.96576 ,24.186796 ,24.498543 ,24.733124 ,24.913122 ,
        25.114826 ,25.411213 ,25.675262 ,25.863028 ,26.050789 ,26.314838 ,
        26.611223 ,26.812925 ,26.992926 ,27.227505 ,27.550882 ,27.824034 ,
        28.046684 ,28.236614 ,28.406433 ,28.473265 ,28.450163 ,28.394344 ,
        24.429443 ,24.60767 ,24.80497 ,25.04041 ,25.333065 ,25.642756 ,
        25.853743 ,26.032173 ,26.25321 ,26.564959 ,26.79954 ,26.97954 ,
        27.181242 ,27.47763 ,27.74168 ,27.929441 ,28.117207 ,28.381254 ,
        28.677637 ,28.879343 ,29.059345 ,29.293922 ,29.617298 ,29.890451 ,
        30.113104 ,30.303034 ,30.472853 ,30.539684 ,30.516582 ,30.460762 ,
        26. ,26.178228 ,26.375526 ,26.61097 ,26.903624 ,27.213314 ,
        27.424305 ,27.602734 ,27.823772 ,28.135519 ,28.3701 ,28.550098 ,
        28.7518 ,29.04819 ,29.312237 ,29.5 ,29.687763 ,29.951813 ,
        30.2482 ,30.449903 ,30.629902 ,30.864483 ,31.187859 ,31.461012 ,
        31.683659 ,31.873592 ,32.043407 ,32.11024 ,32.087135 ,32.03132 ,
        27.570559 ,27.748787 ,27.946087 ,28.181528 ,28.474184 ,28.783876 ,
        28.994865 ,29.173294 ,29.39433 ,29.70608 ,29.940659 ,30.120655 ,
        30.32236 ,30.618746 ,30.882797 ,31.070557 ,31.25832 ,31.522371 ,
        31.818754 ,32.02046 ,32.20046 ,32.43504 ,32.758415 ,33.031567 ,
        33.25422 ,33.44415 ,33.613964 ,33.680794 ,33.657696 ,33.60188 ,
        29.636976 ,29.815207 ,30.0125 ,30.247944 ,30.5406 ,30.85029 ,
        31.061283 ,31.239712 ,31.46075 ,31.7725 ,32.00708 ,32.187077 ,
        32.38878 ,32.685165 ,32.949215 ,33.13698 ,33.32474 ,33.58879 ,
        33.885178 ,34.086884 ,34.26688 ,34.501457 ,34.824837 ,35.09799 ,
        35.320637 ,35.510574 ,35.68039 ,35.747215 ,35.724117 ,35.6683 ,
        31.83344 ,32.011665 ,32.20897 ,32.444412 ,32.73707 ,33.046757 ,
        33.257744 ,33.436176 ,33.657207 ,33.96896 ,34.203537 ,34.383537 ,
        34.58524 ,34.88163 ,35.145676 ,35.33344 ,35.521206 ,35.785255 ,
        36.081642 ,36.28334 ,36.46334 ,36.69792 ,37.021297 ,37.294453 ,
        37.517097 ,37.707027 ,37.876846 ,37.94368 ,37.920578 ,37.864758 ,
        33.253647 ,33.431873 ,33.62917 ,33.864613 ,34.15727 ,34.466957 ,
        34.677948 ,34.856377 ,35.077415 ,35.38916 ,35.623745 ,35.803745 ,
        36.005447 ,36.301834 ,36.565884 ,36.753647 ,36.941406 ,37.205456 ,
        37.50184 ,37.703545 ,37.883545 ,38.118122 ,38.4415 ,38.714653 ,
        38.9373 ,39.127235 ,39.297054 ,39.363884 ,39.340782 ,39.28496 ,
        34.464783 ,34.64301 ,34.840305 ,35.075752 ,35.368404 ,35.6781 ,
        35.889088 ,36.067516 ,36.28855 ,36.6003 ,36.834885 ,37.014877 ,
        37.216583 ,37.51297 ,37.77702 ,37.964783 ,38.152546 ,38.416595 ,
        38.71298 ,38.914684 ,39.094685 ,39.32926 ,39.652645 ,39.925793 ,
        40.14844 ,40.338375 ,40.508194 ,40.575024 ,40.55192 ,40.496105 ,
        36.058067 ,36.23629 ,36.43359 ,36.669033 ,36.961685 ,37.271378 ,
        37.48237 ,37.6608 ,37.881836 ,38.19359 ,38.42817 ,38.608162 ,
        38.809868 ,39.10625 ,39.3703 ,39.558064 ,39.74583 ,40.00988 ,
        40.306267 ,40.50797 ,40.68797 ,40.92255 ,41.245926 ,41.519077 ,
        41.741722 ,41.931652 ,42.101475 ,42.168304 ,42.145203 ,42.089386 ,
        38.315002 ,38.493233 ,38.690533 ,38.925976 ,39.218628 ,39.52832 ,
        39.739307 ,39.917736 ,40.138775 ,40.45052 ,40.685104 ,40.865097 ,
        41.066803 ,41.36319 ,41.627243 ,41.815002 ,42.002766 ,42.26682 ,
        42.5632 ,42.764908 ,42.944904 ,43.179485 ,43.50286 ,43.776016 ,
        43.998665 ,44.188595 ,44.358418 ,44.425247 ,44.402145 ,44.34633 ,
        40.22708 ,40.40531 ,40.602608 ,40.83805 ,41.130707 ,41.440395 ,
        41.651382 ,41.82982 ,42.050854 ,42.3626 ,42.597183 ,42.77718 ,
        42.97888 ,43.27527 ,43.53932 ,43.72708 ,43.914845 ,44.178894 ,
        44.47528 ,44.676983 ,44.856983 ,45.09156 ,45.41494 ,45.68809 ,
        45.91074 ,46.100674 ,46.270493 ,46.337322 ,46.31422 ,46.2584 ,
        41.785618 ,41.963844 ,42.161144 ,42.396584 ,42.68924 ,42.998936 ,
        43.209923 ,43.388355 ,43.609394 ,43.921143 ,44.15572 ,44.335716 ,
        44.53742 ,44.833805 ,45.09786 ,45.285614 ,45.473377 ,45.737427 ,
        46.033817 ,46.235523 ,46.415524 ,46.650105 ,46.973476 ,47.24663 ,
        47.469276 ,47.65921 ,47.82903 ,47.895855 ,47.872753 ,47.81694 ,
        43.11514 ,43.293365 ,43.490665 ,43.726105 ,44.018764 ,44.328457 ,
        44.539444 ,44.717873 ,44.93891 ,45.25066 ,45.48524 ,45.665237 ,
        45.86694 ,46.163326 ,46.427376 ,46.615143 ,46.802902 ,47.066956 ,
        47.363342 ,47.56505 ,47.74505 ,47.979626 ,48.302998 ,48.576153 ,
        48.798798 ,48.98873 ,49.158546 ,49.225376 ,49.202282 ,49.146458 ,
        44.303867 ,44.482094 ,44.679394 ,44.914833 ,45.207493 ,45.51718 ,
        45.72817 ,45.9066 ,46.12764 ,46.439384 ,46.673965 ,46.853966 ,
        47.055668 ,47.352055 ,47.6161 ,47.803867 ,47.99163 ,48.25568 ,
        48.552063 ,48.75377 ,48.933773 ,49.16835 ,49.491726 ,49.764877 ,
        49.987526 ,50.17746 ,50.347275 ,50.4141 ,50.391006 ,50.335186 ,
        44.771675 ,44.949905 ,45.1472 ,45.382645 ,45.6753 ,45.98499 ,
        46.195976 ,46.374413 ,46.595448 ,46.907196 ,47.141773 ,47.321774 ,
        47.523476 ,47.819862 ,48.08391 ,48.27168 ,48.459446 ,48.72349 ,
        49.019882 ,49.22158 ,49.401585 ,49.63616 ,49.959538 ,50.232693 ,
        50.455338 ,50.64527 ,50.81509 ,50.88192 ,50.858818 ,50.803 ,
        44.609966 ,44.788193 ,44.985493 ,45.220936 ,45.51359 ,45.82328 ,
        46.03427 ,46.2127 ,46.433743 ,46.74549 ,46.98007 ,47.160065 ,
        47.36177 ,47.658157 ,47.922207 ,48.10997 ,48.297733 ,48.561783 ,
        48.858166 ,49.059875 ,49.239872 ,49.47445 ,49.79783 ,50.07098 ,
        50.293625 ,50.48356 ,50.653378 ,50.720203 ,50.6971 ,50.64128 ,
        44.219246 ,44.397472 ,44.594772 ,44.83021 ,45.122868 ,45.43256 ,
        45.643543 ,45.82198 ,46.04302 ,46.354763 ,46.589344 ,46.76934 ,
        46.971046 ,47.267433 ,47.531483 ,47.719242 ,47.907005 ,48.17105 ,
        48.467438 ,48.66914 ,48.849144 ,49.08372 ,49.4071 ,49.680256 ,
        49.902905 ,50.092834 ,50.262653 ,50.329483 ,50.30638 ,50.25057});

    // Target spatial size {height, width}.
    auto size = NDArrayFactory::create<int>({30, 30});

    nd4j::ops::resize_bicubic op;
    // unique_ptr releases the ResultSet even if an ASSERT_* below returns early.
    std::unique_ptr<nd4j::ResultSet> results(op.execute({&input, &size}, {}, {}));
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    NDArray* result = results->at(0);
    // result->printBuffer("Resized to 30x30");
    // expected.printBuffer("Expect for 30x30");

    ASSERT_TRUE(expected.isSameShape(result));
    ASSERT_TRUE(expected.equalsTo(result));
}
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test2) {
    // Bicubic resize of a batch of two 5x4 three-channel images
    // ({batch, height, width, channels}) to 10x8.
    NDArray input = NDArrayFactory::create<double>('c', {2, 5, 4, 3});
    input.linspace(1);   // 1, 2, ..., 120

    NDArray expected = NDArrayFactory::create<double>('c', {2, 10, 8, 3}, {
        1. , 2. ,3. ,2.21875, 3.21875, 4.21875, 4. , 5. , 6. ,5.5,
        6.5, 7.5, 7., 8., 9. ,8.78125, 9.78125, 10.78125, 10., 11. ,
        12., 10.28125 , 11.28125 ,12.28125, 5.875, 6.875, 7.875, 7.09375, 8.09375 ,9.09375,
        8.875, 9.875, 10.875, 10.375, 11.375, 12.375 ,11.875 ,12.875 , 13.875, 13.65625,
        14.65625, 15.65625, 14.875 ,15.875 ,16.875 , 15.15625, 16.15625, 17.15625, 13., 14.,
        15. ,14.21875, 15.21875, 16.21875, 16., 17., 18. ,17.5 ,18.5 , 19.5,
        19., 20., 21., 20.78125 ,21.78125 ,22.78125, 22., 23. , 24. , 22.28125,
        23.28125 ,24.28125 ,19. , 20., 21., 20.21875, 21.21875, 22.21875 ,22. ,23.,
        24. , 23.5, 24.5, 25.5, 25. ,26. ,27., 26.78125 , 27.78125, 28.78125,
        28., 29. ,30. ,28.28125, 29.28125, 30.28125, 25., 26., 27. ,26.21875,
        27.21875, 28.21875, 28., 29., 30., 29.5 ,30.5 ,31.5 , 31., 32.,
        33., 32.78125, 33.78125 ,34.78125 ,34., 35., 36., 34.28125, 35.28125, 36.28125,
        31. ,32., 33. , 32.21875, 33.21875, 34.21875, 34. ,35. ,36., 35.5,
        36.5 , 37.5 , 37., 38. ,39. ,38.78125, 39.78125, 40.78125, 40., 41.,
        42. ,40.28125 ,41.28125, 42.28125, 37. , 38., 39., 38.21875 ,39.21875 ,40.21875,
        40. , 41. , 42. , 41.5, 42.5, 43.5 ,43., 44., 45., 44.78125,
        45.78125, 46.78125 ,46. ,47. , 48. , 46.28125 , 47.28125, 48.28125, 44.125 ,45.125,
        46.125, 45.34375, 46.34375, 47.34375, 47.125, 48.125 ,49.125 ,48.625, 49.625 , 50.625,
        50.125 , 51.125, 52.125 ,51.90625 ,52.90625, 53.90625, 53.125, 54.125, 55.125, 53.40625,
        54.40625 ,55.40625, 49. ,50. , 51. ,50.21875, 51.21875 ,52.21875 ,52. ,53.,
        54. ,53.5 , 54.5, 55.5 ,55. ,56. ,57. ,56.78125 ,57.78125, 58.78125,
        58. ,59. ,60. ,58.28125 ,59.28125 ,60.28125, 50.125, 51.125 ,52.125 ,51.34375,
        52.34375 ,53.34375 ,53.125, 54.125, 55.125 ,54.625 ,55.625 ,56.625 ,56.125 ,57.125,
        58.125, 57.90625 ,58.90625 ,59.90625 ,59.125 ,60.125 ,61.125, 59.40625, 60.40625 ,61.40625,
        61. ,62. ,63. ,62.21875, 63.21875, 64.21875 ,64. ,65. ,66. ,65.5 ,
        66.5, 67.5, 67. ,68. ,69. ,68.78125 ,69.78125 ,70.78125 ,70., 71. ,
        72. ,70.28125 ,71.28125 ,72.28125 ,65.875 ,66.875, 67.875 ,67.09375 ,68.09375 ,69.09375,
        68.875 ,69.875 ,70.875, 70.375 ,71.375 ,72.375 ,71.875 ,72.875 ,73.875 ,73.65625,
        74.65625 ,75.65625 ,74.875 ,75.875, 76.875 ,75.15625 ,76.15625,
        77.15625 ,73. ,74. ,75., 74.21875 ,75.21875 ,76.21875,
        76. ,77. ,78. ,77.5 ,78.5 ,79.5 ,79.,
        80. ,81. ,80.78125 ,81.78125, 82.78125 ,82. ,83.,
        84. ,82.28125 ,83.28125 ,84.28125, 79. ,80. ,81.,
        80.21875 ,81.21875 ,82.21875 ,82., 83. ,84. ,83.5,
        84.5 ,85.5 ,85. ,86., 87. ,86.78125 ,87.78125,
        88.78125 ,88. ,89. ,90., 88.28125 ,89.28125 ,90.28125,
        85. ,86. ,87. ,86.21875, 87.21875 ,88.21875 ,88.,
        89. ,90. ,89.5 ,90.5, 91.5 ,91. ,92.,
        93. ,92.78125 ,93.78125 ,94.78125, 94. ,95. ,96.,
        94.28125 ,95.28125 ,96.28125 ,91., 92. ,93. ,92.21875,
        93.21875 ,94.21875 ,94. ,95., 96. ,95.5 ,96.5,
        97.5 ,97. ,98. ,99., 98.78125 ,99.78125 ,100.78125,
        100. ,101. ,102. ,100.28125, 101.28125 ,102.28125, 97.,
        98. ,99. ,98.21875 ,99.21875, 100.21875 ,100., 101.,
        102. ,101.5 ,102.5 ,103.5, 103. ,104., 105.,
        104.78125 ,105.78125 ,106.78125 ,106., 107. ,108., 106.28125,
        107.28125 ,108.28125 ,104.125 ,105.125, 106.125 ,105.34375, 106.34375,
        107.34375 ,107.125 ,108.125 ,109.125, 108.625 ,109.625, 110.625,
        110.125 ,111.125 ,112.125 ,111.90625, 112.90625 ,113.90625, 113.125,
        114.125 ,115.125 ,113.40625 ,114.40625, 115.40625 ,109., 110.,
        111. ,110.21875 ,111.21875 ,112.21875, 112., 113., 114.,
        113.5 ,114.5 ,115.5 ,115., 116., 117., 116.78125,
        117.78125 ,118.78125 ,118. ,119., 120., 118.28125, 119.28125,
        120.28125 ,110.125 ,111.125 ,112.125, 111.34375, 112.34375, 113.34375,
        113.125 ,114.125 ,115.125 ,114.625, 115.625, 116.625, 116.125,
        117.125 ,118.125 ,117.90625, 118.90625, 119.90625, 119.125, 120.125,
        121.125 ,119.40625 ,120.40625, 121.40625});

    // Target spatial size {height, width}.
    auto size = NDArrayFactory::create<int>({10, 8});

    nd4j::ops::resize_bicubic op;
    // unique_ptr releases the ResultSet even if an ASSERT_* below returns early.
    std::unique_ptr<nd4j::ResultSet> results(op.execute({&input, &size}, {}, {}));
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    NDArray* result = results->at(0);
    // result->printBuffer("Resized to 10x8");
    // expected.printBuffer("Expect for 10x8");

    ASSERT_TRUE(expected.isSameShape(result));
    ASSERT_TRUE(expected.equalsTo(result));
}
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test3) {
    // Bicubic upscale of a 1x3x3x4 ramp (linspace from 1) to 6x6, compared
    // against a hard-coded 1x6x6x4 reference.
    // NOTE(review): a failing ASSERT_* returns early and skips the delete at
    // the end (leaks only on failure).
    NDArray input = NDArrayFactory::create<double>('c', {1, 3, 3, 4});
    input.linspace(1);

    NDArray expected = NDArrayFactory::create<double>('c', {1, 6, 6, 4}, {
        1. ,2. ,3. ,4.,
        2.625 ,3.625 ,4.625 ,5.625,
        5. ,6. ,7. ,8.,
        7.375 ,8.375 ,9.375, 10.375,
        9. ,10. ,11. ,12.,
        9.375 ,10.375 ,11.375 ,12.375,
        5.875 ,6.875 ,7.875 , 8.875 ,
        7.5 ,8.5 ,9.5 , 10.5 ,
        9.875 ,10.875 ,11.875, 12.875,
        12.25 ,13.25 ,14.25 , 15.25 ,
        13.875 ,14.875 ,15.875, 16.875,
        14.25 ,15.25 ,16.25 , 17.25 ,
        13. ,14. ,15. ,16.,
        14.625 ,15.625 ,16.625 ,17.625,
        17. ,18. ,19. ,20.,
        19.375 ,20.375 ,21.375 ,22.375,
        21. ,22. ,23. ,24.,
        21.375 ,22.375 ,23.375 ,24.375,
        20.125 ,21.125 ,22.125 ,23.125,
        21.75 ,22.75 ,23.75 ,24.75,
        24.125 ,25.125 ,26.125 ,27.125,
        26.5 ,27.5 ,28.5 ,29.5,
        28.125 ,29.125 ,30.125 ,31.125,
        28.5 ,29.5 ,30.5 ,31.5,
        25. , 26. , 27. , 28.,
        26.625 ,27.625 ,28.625 ,29.625,
        29. ,30. ,31. ,32.,
        31.375 ,32.375 ,33.375 ,34.375,
        33. ,34. ,35. ,36.,
        33.375 ,34.375 ,35.375 ,36.375,
        26.125, 27.125, 28.125, 29.125,
        27.75 ,28.75 ,29.75 ,30.75,
        30.125 ,31.125 ,32.125 ,33.125,
        32.5 ,33.5 ,34.5 ,35.5,
        34.125 ,35.125 ,36.125 ,37.125,
        34.5 ,35.5 ,36.5 ,37.5
    });

    auto newSize = NDArrayFactory::create<int>({6, 6});

    nd4j::ops::resize_bicubic resizeOp;
    auto outputs = resizeOp.execute({&input, &newSize}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* resized = outputs->at(0);
    ASSERT_TRUE(expected.isSameShape(resized));
    ASSERT_TRUE(expected.equalsTo(resized));

    delete outputs;
}
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test4) {
    // Bicubic upscale of a 1x3x4x3 ramp (linspace from 1) to 6x8, compared
    // against a hard-coded 1x6x8x3 reference.
    // NOTE(review): a failing ASSERT_* returns early and skips the delete at
    // the end (leaks only on failure).
    NDArray input = NDArrayFactory::create<double>('c', {1, 3, 4, 3});
    input.linspace(1);

    NDArray expected = NDArrayFactory::create<double>('c', {1, 6, 8, 3}, {
        1. , 2. , 3. ,
        2.21875 ,3.21875 ,4.21875,
        4. ,5. ,6. ,
        5.5 ,6.5 ,7.5 ,
        7. ,8. ,9. ,
        8.78125 ,9.78125, 10.78125,
        10. ,11., 12. ,
        10.28125 ,11.28125, 12.28125,
        5.875 , 6.875 , 7.875 ,
        7.09375 , 8.09375 , 9.09375,
        8.875 , 9.875 ,10.875 ,
        10.375 ,11.375 ,12.375 ,
        11.875 ,12.875 ,13.875 ,
        13.65625 ,14.65625 ,15.65625,
        14.875 ,15.875 ,16.875 ,
        15.15625 ,16.15625 ,17.15625,
        13., 14., 15.,
        14.21875 ,15.21875 ,16.21875,
        16. ,17. ,18. ,
        17.5 ,18.5 ,19.5 ,
        19. ,20. ,21. ,
        20.78125 ,21.78125 ,22.78125,
        22. ,23. ,24. ,
        22.28125 ,23.28125 ,24.28125,
        20.125 , 21.125 , 22.125,
        21.34375 ,22.34375 ,23.34375,
        23.125 ,24.125 ,25.125 ,
        24.625 ,25.625 ,26.625 ,
        26.125 ,27.125 ,28.125 ,
        27.90625 ,28.90625 ,29.90625,
        29.125 ,30.125 ,31.125 ,
        29.40625 ,30.40625 ,31.40625,
        25. ,26. ,27. ,
        26.21875 ,27.21875 ,28.21875,
        28. ,29. ,30. ,
        29.5 ,30.5 ,31.5 ,
        31. ,32. ,33. ,
        32.78125 ,33.78125 ,34.78125,
        34. ,35. ,36. ,
        34.28125 ,35.28125 ,36.28125,
        26.125 ,27.125 , 28.125 ,
        27.34375 ,28.34375 ,29.34375,
        29.125 ,30.125 ,31.125 ,
        30.625 ,31.625 ,32.625 ,
        32.125 ,33.125 ,34.125 ,
        33.90625 ,34.90625 ,35.90625,
        35.125 ,36.125 ,37.125 ,
        35.40625 ,36.40625 ,37.40625 });

    auto newSize = NDArrayFactory::create<int>({6, 8});

    nd4j::ops::resize_bicubic resizeOp;
    auto outputs = resizeOp.execute({&input, &newSize}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* resized = outputs->at(0);
    ASSERT_TRUE(expected.isSameShape(resized));
    ASSERT_TRUE(expected.equalsTo(resized));

    delete outputs;
}
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test5) {
    // Bicubic upscale of a 1x4x4x3 ramp (linspace from 1) to 8x8, compared
    // against a hard-coded 1x8x8x3 reference.
    // NOTE(review): a failing ASSERT_* returns early and skips the delete at
    // the end (leaks only on failure).
    NDArray input = NDArrayFactory::create<double>('c', {1, 4, 4, 3});
    input.linspace(1);

    NDArray expected = NDArrayFactory::create<double>('c', {1, 8, 8, 3}, {
        1. ,2. , 3. , 2.21875 , 3.21875 , 4.21875 , 4. , 5. ,
        6. ,5.5 , 6.5 , 7.5 , 7. , 8. , 9. , 8.78125 ,
        9.78125 ,10.78125 ,10. ,11. ,12. ,10.28125 ,11.28125 ,12.28125 ,
        5.875 ,6.875 , 7.875 , 7.09375 , 8.09375 , 9.09375 , 8.875 , 9.875 ,
        10.875 ,10.375 , 11.375 , 12.375 , 11.875 , 12.875 , 13.875 , 13.65625,
        14.65625 ,15.65625, 14.875 , 15.875 , 16.875 , 15.15625, 16.15625, 17.15625,
        13. ,14. , 15. , 14.21875, 15.21875, 16.21875, 16. , 17. ,
        18. ,17.5 , 18.5 , 19.5 , 19. , 20. , 21. , 20.78125,
        21.78125 ,22.78125, 22. , 23. , 24. , 22.28125, 23.28125, 24.28125,
        19. ,20. , 21. , 20.21875, 21.21875, 22.21875, 22. , 23. ,
        24. ,23.5 , 24.5 , 25.5 , 25. , 26. , 27. , 26.78125,
        27.78125 ,28.78125, 28. , 29. , 30. , 28.28125, 29.28125, 30.28125,
        25. ,26. , 27. , 26.21875, 27.21875, 28.21875, 28. , 29. ,
        30. ,29.5 , 30.5 , 31.5 , 31. , 32. , 33. , 32.78125,
        33.78125 ,34.78125, 34. , 35. , 36. , 34.28125, 35.28125, 36.28125,
        32.125 ,33.125 , 34.125 , 33.34375, 34.34375, 35.34375, 35.125 , 36.125 ,
        37.125 ,36.625 , 37.625 , 38.625 , 38.125 , 39.125 , 40.125 , 39.90625,
        40.90625 ,41.90625, 41.125 , 42.125 , 43.125 , 41.40625, 42.40625, 43.40625,
        37. ,38. , 39. , 38.21875, 39.21875, 40.21875, 40. , 41. ,
        42. ,41.5 , 42.5 , 43.5 , 43. , 44. , 45. , 44.78125,
        45.78125 ,46.78125, 46. , 47. , 48. , 46.28125, 47.28125, 48.28125,
        38.125 ,39.125 , 40.125 , 39.34375, 40.34375, 41.34375, 41.125 , 42.125 ,
        43.125 ,42.625 , 43.625 , 44.625 , 44.125 , 45.125 , 46.125 , 45.90625,
        46.90625 ,47.90625, 47.125 , 48.125 , 49.125 , 47.40625, 48.40625, 49.40625,
    });

    auto newSize = NDArrayFactory::create<int>({8, 8});

    nd4j::ops::resize_bicubic resizeOp;
    auto outputs = resizeOp.execute({&input, &newSize}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* resized = outputs->at(0);
    ASSERT_TRUE(expected.isSameShape(resized));
    ASSERT_TRUE(expected.equalsTo(resized));

    delete outputs;
}
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) {
    // Bicubic upscale of an explicit 7x7x1 (rank-3, no batch axis) input to
    // 30x30, compared against a hard-coded 30x30x1 reference.
    // NOTE(review): a failing ASSERT_* returns early and skips the delete at
    // the end (leaks only on failure).
    NDArray input = NDArrayFactory::create<double>('c', {7, 7, 1}, {
        1, 2.1, 3.15, 4.2, 5.15, 6.1, 7,
        8, 9.1, 10., 11, 12.9, 13.1, 14,
        15, 16., 17., 18, 19, 20., 21,
        22, 23., 24., 25, 26, 27, 28,
        30, 31, 32, 33, 34., 35, 36,
        37, 38, 39, 40, 41., 42, 43,
        44, 45, 46, 47, 48., 49, 50
    });

    NDArray expected = NDArrayFactory::create<double>('c', {30, 30, 1}, {
        1. ,1.1976162 ,1.4174359 ,1.6775769 ,1.9961575 ,2.3283265 ,
        2.550918 ,2.7360606 ,2.9655411 ,3.2929654 ,3.5441515 ,3.7380352 ,
        3.948995 ,4.248106 ,4.5073795 ,4.6843743 ,4.8572845 ,5.104302 ,
        5.3869915 ,5.581401 ,5.7539616 ,5.974285 ,6.272836 ,6.5204263 ,
        6.718899 ,6.8871036 ,7.039068 ,7.099216 ,7.0784245 ,7.0281887 ,
        2.247592 ,2.446947 ,2.6694887 ,2.9312382 ,3.248216 ,3.5745337 ,
        3.78931 ,3.9656973 ,4.186417 ,4.5046535 ,4.740569 ,4.9217057 ,
        5.133866 ,5.459533 ,5.7744613 ,6.0197873 ,6.254011 ,6.535633 ,
        6.8097296 ,6.9607787 ,7.0749416 ,7.241601 ,7.5094895 ,7.7499495 ,
        7.954571 ,8.131972 ,8.286526 ,8.346463 ,8.325745 ,8.275683 ,
        3.6286845 ,3.830573 ,4.0569587 ,4.3211575 ,4.6364856 ,4.9556503 ,
        5.160583 ,5.3258467 ,5.535462 ,5.84216 ,6.058749 ,6.223753 ,
        6.437597 ,6.797369 ,7.1836042 ,7.5164022 ,7.8290343 ,8.154773 ,
        8.417635 ,8.512958 ,8.5521 ,8.649708 ,8.87788 ,9.108794 ,
        9.320926 ,9.509781 ,9.667375 ,9.72694 ,9.706349 ,9.656599 ,
        5.276778 ,5.480438 ,5.709702 ,5.9754477 ,6.288551 ,6.6005697 ,
        6.796207 ,6.9511423 ,7.1503997 ,7.4461427 ,7.644651 ,7.794562 ,
        8.009684 ,8.400473 ,8.851847 ,9.26469 ,9.649218, 10.015648 ,
        10.268647 ,10.313368 ,10.2843275 ,10.319379 ,10.512033 ,10.734956 ,
        10.954604 ,11.154507 ,11.315369 ,11.374779 ,11.354242 ,11.304622 ,
        7.325373 ,7.5284843 ,7.757575 ,8.022221 ,8.331997 ,8.638187 ,
        8.827649 ,8.976217 ,9.168955 ,9.45726 ,9.6442375 ,9.784517 ,
        9.999621, 10.407702 ,10.896234, 11.355122, 11.781423, 12.172186 ,
        12.420712 ,12.4374485 ,12.370511 ,12.371386 ,12.545973 ,12.766424 ,
        12.992249 ,13.20012 ,13.364252 ,13.424109 ,13.40342 ,13.353425 ,
        9.493208 ,9.692467 ,9.9169445, 10.176801, 10.482199, 10.78547 ,
        10.974367 ,11.123442 ,11.31637 ,11.603645 ,11.790616 ,11.930889 ,
        12.144082 ,12.546447 ,13.024898 ,13.4723 ,13.889232 ,14.276275 ,
        14.528972 ,14.555555 ,14.50145 ,14.515459 ,14.700572 ,14.927055 ,
        15.156046 ,15.366046 ,15.532901 ,15.594008 ,15.5728855 ,15.521847 ,
        10.970133 ,11.163599 ,11.380694 ,11.633735 ,11.935032 ,12.238887 ,
        12.43254 ,12.588294 ,12.787534 ,13.079956 ,13.27752 ,13.426631 ,
        13.636713 ,14.013844 ,14.441672 ,14.827978 ,15.191209 ,15.549808 ,
        15.81343 ,15.881828 ,15.883522 ,15.950411 ,16.16933 ,16.40794 ,
        16.636436 ,16.842583 ,17.010887 ,17.07363 ,17.05194 ,16.999537 ,
        12.219155 ,12.406129 ,12.614796 ,12.860335 ,13.157928 ,13.464224 ,
        13.665207 ,13.830567 ,14.039036 ,14.339629 ,14.552863 ,14.715049 ,
        14.921564 ,15.264454 ,15.622843 ,15.924977 ,16.213829 ,16.532364 ,
        16.8099 ,16.934835 ,17.012146 ,17.150164 ,17.413412 ,17.666712 ,
        17.892765 ,18.09207 ,18.261044 ,18.325531 ,18.303238 ,18.249378 ,
        13.7663965 ,13.947391 ,14.148263 ,14.386917 ,14.681246 ,14.990087 ,
        15.198166 ,15.372728 ,15.590062 ,15.898583 ,16.126892 ,16.301655 ,
        16.50487 ,16.815214 ,17.107498 ,17.329458 ,17.547403 ,17.827654 ,
        18.118288 ,18.296928 ,18.4461 ,18.651634 ,18.956806 ,19.22382 ,
        19.447308 ,19.639887 ,19.809319 ,19.875397 ,19.852556 ,19.797365 ,
        15.9419365 ,16.118704 ,16.314133 ,16.547867 ,16.839561 ,17.14954 ,
        17.361883 ,17.542162 ,17.764957 ,18.078188 ,18.315733 ,18.498205 ,
        18.699116 ,18.988684 ,19.238989 ,19.410137 ,19.583265 ,19.839512 ,
        20.13878 ,20.35177 ,20.546844 ,20.795671 ,21.128067 ,21.404358 ,
        21.626736 ,21.8155 ,21.98561 ,22.052843 ,22.029604 ,21.973448 ,
        17.53522 ,17.71077 ,17.904636 ,18.13695 ,18.42784 ,18.738056 ,
        18.951529 ,19.133352 ,19.357613 ,19.672083 ,19.912102 ,20.096638 ,
        20.296894 ,20.580765 ,20.819603 ,20.976887 ,21.137802 ,21.387535 ,
        21.689209 ,21.911621 ,22.119276 ,22.37999 ,22.71991 ,22.998823 ,
        23.22097 ,23.40876 ,23.57911 ,23.646685 ,23.623325 ,23.566887 ,
        18.746353 ,18.922657 ,19.117487 ,19.350685 ,19.64207 ,19.952137 ,
        20.164913 ,20.345781 ,20.569134 ,20.88284 ,21.12133 ,21.30459 ,
        21.505253 ,21.792645 ,22.038572 ,22.204426 ,22.37289 ,22.626648 ,
        22.926834 ,23.143423 ,23.343302 ,23.596668 ,23.931936 ,24.209232 ,
        24.431519 ,24.619913 ,24.79011 ,24.857473 ,24.83419 ,24.777927 ,
        20.16656 ,20.344206 ,20.540766 ,20.775532 ,21.067804 ,21.377607 ,
        21.589132 ,21.768297 ,21.99003 ,22.302366 ,22.538124 ,22.719105 ,
        22.920494 ,23.214176 ,23.472767 ,23.653934 ,23.83589 ,24.096842 ,
        24.394371 ,24.600555 ,24.786541 ,25.026773 ,25.353731 ,25.62813 ,
        25.850672 ,26.04014 ,26.210072 ,26.277063 ,26.253906 ,26.197956 ,
        22.363024 ,22.54125 ,22.738552 ,22.973991 ,23.266647 ,23.57634 ,
        23.787327 ,23.96576 ,24.186796 ,24.498543 ,24.733124 ,24.913122 ,
        25.114826 ,25.411213 ,25.675262 ,25.863028 ,26.050789 ,26.314838 ,
        26.611223 ,26.812925 ,26.992926 ,27.227505 ,27.550882 ,27.824034 ,
        28.046684 ,28.236614 ,28.406433 ,28.473265 ,28.450163 ,28.394344 ,
        24.429443 ,24.60767 ,24.80497 ,25.04041 ,25.333065 ,25.642756 ,
        25.853743 ,26.032173 ,26.25321 ,26.564959 ,26.79954 ,26.97954 ,
        27.181242 ,27.47763 ,27.74168 ,27.929441 ,28.117207 ,28.381254 ,
        28.677637 ,28.879343 ,29.059345 ,29.293922 ,29.617298 ,29.890451 ,
        30.113104 ,30.303034 ,30.472853 ,30.539684 ,30.516582 ,30.460762 ,
        26. ,26.178228 ,26.375526 ,26.61097 ,26.903624 ,27.213314 ,
        27.424305 ,27.602734 ,27.823772 ,28.135519 ,28.3701 ,28.550098 ,
        28.7518 ,29.04819 ,29.312237 ,29.5 ,29.687763 ,29.951813 ,
        30.2482 ,30.449903 ,30.629902 ,30.864483 ,31.187859 ,31.461012 ,
        31.683659 ,31.873592 ,32.043407 ,32.11024 ,32.087135 ,32.03132 ,
        27.570559 ,27.748787 ,27.946087 ,28.181528 ,28.474184 ,28.783876 ,
        28.994865 ,29.173294 ,29.39433 ,29.70608 ,29.940659 ,30.120655 ,
        30.32236 ,30.618746 ,30.882797 ,31.070557 ,31.25832 ,31.522371 ,
        31.818754 ,32.02046 ,32.20046 ,32.43504 ,32.758415 ,33.031567 ,
        33.25422 ,33.44415 ,33.613964 ,33.680794 ,33.657696 ,33.60188 ,
        29.636976 ,29.815207 ,30.0125 ,30.247944 ,30.5406 ,30.85029 ,
        31.061283 ,31.239712 ,31.46075 ,31.7725 ,32.00708 ,32.187077 ,
        32.38878 ,32.685165 ,32.949215 ,33.13698 ,33.32474 ,33.58879 ,
        33.885178 ,34.086884 ,34.26688 ,34.501457 ,34.824837 ,35.09799 ,
        35.320637 ,35.510574 ,35.68039 ,35.747215 ,35.724117 ,35.6683 ,
        31.83344 ,32.011665 ,32.20897 ,32.444412 ,32.73707 ,33.046757 ,
        33.257744 ,33.436176 ,33.657207 ,33.96896 ,34.203537 ,34.383537 ,
        34.58524 ,34.88163 ,35.145676 ,35.33344 ,35.521206 ,35.785255 ,
        36.081642 ,36.28334 ,36.46334 ,36.69792 ,37.021297 ,37.294453 ,
        37.517097 ,37.707027 ,37.876846 ,37.94368 ,37.920578 ,37.864758 ,
        33.253647 ,33.431873 ,33.62917 ,33.864613 ,34.15727 ,34.466957 ,
        34.677948 ,34.856377 ,35.077415 ,35.38916 ,35.623745 ,35.803745 ,
        36.005447 ,36.301834 ,36.565884 ,36.753647 ,36.941406 ,37.205456 ,
        37.50184 ,37.703545 ,37.883545 ,38.118122 ,38.4415 ,38.714653 ,
        38.9373 ,39.127235 ,39.297054 ,39.363884 ,39.340782 ,39.28496 ,
        34.464783 ,34.64301 ,34.840305 ,35.075752 ,35.368404 ,35.6781 ,
        35.889088 ,36.067516 ,36.28855 ,36.6003 ,36.834885 ,37.014877 ,
        37.216583 ,37.51297 ,37.77702 ,37.964783 ,38.152546 ,38.416595 ,
        38.71298 ,38.914684 ,39.094685 ,39.32926 ,39.652645 ,39.925793 ,
        40.14844 ,40.338375 ,40.508194 ,40.575024 ,40.55192 ,40.496105 ,
        36.058067 ,36.23629 ,36.43359 ,36.669033 ,36.961685 ,37.271378 ,
        37.48237 ,37.6608 ,37.881836 ,38.19359 ,38.42817 ,38.608162 ,
        38.809868 ,39.10625 ,39.3703 ,39.558064 ,39.74583 ,40.00988 ,
        40.306267 ,40.50797 ,40.68797 ,40.92255 ,41.245926 ,41.519077 ,
        41.741722 ,41.931652 ,42.101475 ,42.168304 ,42.145203 ,42.089386 ,
        38.315002 ,38.493233 ,38.690533 ,38.925976 ,39.218628 ,39.52832 ,
        39.739307 ,39.917736 ,40.138775 ,40.45052 ,40.685104 ,40.865097 ,
        41.066803 ,41.36319 ,41.627243 ,41.815002 ,42.002766 ,42.26682 ,
        42.5632 ,42.764908 ,42.944904 ,43.179485 ,43.50286 ,43.776016 ,
        43.998665 ,44.188595 ,44.358418 ,44.425247 ,44.402145 ,44.34633 ,
        40.22708 ,40.40531 ,40.602608 ,40.83805 ,41.130707 ,41.440395 ,
        41.651382 ,41.82982 ,42.050854 ,42.3626 ,42.597183 ,42.77718 ,
        42.97888 ,43.27527 ,43.53932 ,43.72708 ,43.914845 ,44.178894 ,
        44.47528 ,44.676983 ,44.856983 ,45.09156 ,45.41494 ,45.68809 ,
        45.91074 ,46.100674 ,46.270493 ,46.337322 ,46.31422 ,46.2584 ,
        41.785618 ,41.963844 ,42.161144 ,42.396584 ,42.68924 ,42.998936 ,
        43.209923 ,43.388355 ,43.609394 ,43.921143 ,44.15572 ,44.335716 ,
        44.53742 ,44.833805 ,45.09786 ,45.285614 ,45.473377 ,45.737427 ,
        46.033817 ,46.235523 ,46.415524 ,46.650105 ,46.973476 ,47.24663 ,
        47.469276 ,47.65921 ,47.82903 ,47.895855 ,47.872753 ,47.81694 ,
        43.11514 ,43.293365 ,43.490665 ,43.726105 ,44.018764 ,44.328457 ,
        44.539444 ,44.717873 ,44.93891 ,45.25066 ,45.48524 ,45.665237 ,
        45.86694 ,46.163326 ,46.427376 ,46.615143 ,46.802902 ,47.066956 ,
        47.363342 ,47.56505 ,47.74505 ,47.979626 ,48.302998 ,48.576153 ,
        48.798798 ,48.98873 ,49.158546 ,49.225376 ,49.202282 ,49.146458 ,
        44.303867 ,44.482094 ,44.679394 ,44.914833 ,45.207493 ,45.51718 ,
        45.72817 ,45.9066 ,46.12764 ,46.439384 ,46.673965 ,46.853966 ,
        47.055668 ,47.352055 ,47.6161 ,47.803867 ,47.99163 ,48.25568 ,
        48.552063 ,48.75377 ,48.933773 ,49.16835 ,49.491726 ,49.764877 ,
        49.987526 ,50.17746 ,50.347275 ,50.4141 ,50.391006 ,50.335186 ,
        44.771675 ,44.949905 ,45.1472 ,45.382645 ,45.6753 ,45.98499 ,
        46.195976 ,46.374413 ,46.595448 ,46.907196 ,47.141773 ,47.321774 ,
        47.523476 ,47.819862 ,48.08391 ,48.27168 ,48.459446 ,48.72349 ,
        49.019882 ,49.22158 ,49.401585 ,49.63616 ,49.959538 ,50.232693 ,
        50.455338 ,50.64527 ,50.81509 ,50.88192 ,50.858818 ,50.803 ,
        44.609966 ,44.788193 ,44.985493 ,45.220936 ,45.51359 ,45.82328 ,
        46.03427 ,46.2127 ,46.433743 ,46.74549 ,46.98007 ,47.160065 ,
        47.36177 ,47.658157 ,47.922207 ,48.10997 ,48.297733 ,48.561783 ,
        48.858166 ,49.059875 ,49.239872 ,49.47445 ,49.79783 ,50.07098 ,
        50.293625 ,50.48356 ,50.653378 ,50.720203 ,50.6971 ,50.64128 ,
        44.219246 ,44.397472 ,44.594772 ,44.83021 ,45.122868 ,45.43256 ,
        45.643543 ,45.82198 ,46.04302 ,46.354763 ,46.589344 ,46.76934 ,
        46.971046 ,47.267433 ,47.531483 ,47.719242 ,47.907005 ,48.17105 ,
        48.467438 ,48.66914 ,48.849144 ,49.08372 ,49.4071 ,49.680256 ,
        49.902905 ,50.092834 ,50.262653 ,50.329483 ,50.30638 ,50.25057});

    auto newSize = NDArrayFactory::create<int>({30, 30});

    nd4j::ops::resize_bicubic resizeOp;
    auto outputs = resizeOp.execute({&input, &newSize}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* resized = outputs->at(0);
    ASSERT_TRUE(expected.isSameShape(resized));
    ASSERT_TRUE(expected.equalsTo(resized));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {
    // Exercises copy-assignment (arr[i] = ...) and copy-construction (var3)
    // of SummaryStatsData<double>.
    using Stats = functions::summarystats::SummaryStatsData<double>;

    Stats var1;
    Stats var2;
    var2.n = var2.mean = var2.M2 = var2.M3 = var2.M4 = var2.bias = 5;

    // Automatic-storage array instead of new[]/delete[]: a failing ASSERT_*
    // returns early from the test body, which previously leaked the heap
    // array because the trailing delete[] was never reached.
    Stats arr[2];
    arr[0] = var1;
    arr[1] = var2;
    arr[0] = arr[1];   // overwrite with var2's values via copy-assignment

    Stats var3(var1);  // copy of the default-constructed var1

    ASSERT_TRUE(arr[0].n == arr[0].mean && arr[0].M2 == arr[0].M3 && arr[0].n == 5);
    ASSERT_TRUE(arr[1].n == arr[1].mean && arr[1].M2 == arr[1].M3 && arr[1].n == 5);
    // NOTE(review): assumes SummaryStatsData's default constructor zeroes its
    // fields (not visible here) — otherwise this check reads indeterminate values.
    ASSERT_TRUE(var3.n == var3.mean && var3.M2 == var3.M3 && var3.n == 0);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) {
    // Full-shape weights, integer arg 0 (presumably the reduction mode —
    // confirm against the op). Outputs 0/1/2 are checked as gradients w.r.t.
    // predictions, weights and labels; the labels gradient must equal
    // -dL/dpredictions.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    weights.assign(0.5);
    labels.linspace(1);

    NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52,
                                   -12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04});
    NDArray dLdwExp('c', {2,3,4}, {0.9216 , 3.6864 , 8.2944 , 14.7456 , 23.04 , 33.1776 , 45.1584 , 58.9824 , 74.6496 , 92.16 ,111.51361,132.7104 ,
                                   155.75038,180.63359,207.35999,235.9296 ,266.34238,298.59842,332.6976 ,368.64001,406.4256 ,446.05444,487.5264 ,530.84161});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto gradients = lossGrad.execute({&predictions, &weights, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, gradients->status());

    auto* dLdp = gradients->at(0);
    auto* dLdw = gradients->at(1);
    auto* dLdl = gradients->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
    ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));

    delete gradients;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test2) {
    // Weights of shape {2,1,4}, integer arg 0; only the weights gradient
    // (output 1) is checked, and it must keep the weights' shape.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,1,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    weights.assign(0.5);
    labels.linspace(1);

    NDArray dLdwExp('c', {2,1,4}, {98.61121,129.024 , 164.9664 , 206.4384 , 828.51837,925.28644,1027.58398,1135.41113});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto gradients = lossGrad.execute({&predictions, &weights, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, gradients->status());

    auto* dLdw = gradients->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete gradients;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) {
    // Scalar weights, integer arg 1; checks all three gradients, with the
    // weights gradient expected as a scalar.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    weights.assign(0.5);
    labels.linspace(1);

    NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52,
                                   -12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04});
    NDArray dLdwExp('c', {}, {4515.84});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto gradients = lossGrad.execute({&predictions, &weights, &labels}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, gradients->status());

    auto* dLdp = gradients->at(0);
    auto* dLdw = gradients->at(1);
    auto* dLdl = gradients->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
    ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));

    delete gradients;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test4) {
    // Weights of shape {1,3,1}, integer arg 1; only the weights gradient is
    // checked.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    weights.assign(0.5);
    labels.linspace(1);

    NDArray dLdwExp('c', {1,3,1}, {807.32153, 1426.63684, 2281.88159});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto gradients = lossGrad.execute({&predictions, &weights, &labels}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, gradients->status());

    auto* dLdw = gradients->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete gradients;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test5) {
    // Full-shape weights, integer arg 2; checks all three gradients.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    weights.assign(0.5);
    labels.linspace(1);

    NDArray dLdpExp('c', {2,3,4}, {-0.08,-0.16,-0.24,-0.32,-0.4 ,-0.48,-0.56,-0.64,-0.72,-0.8 ,-0.88,-0.96,
                                   -1.04,-1.12,-1.2 ,-1.28,-1.36,-1.44,-1.52,-1.6 ,-1.68,-1.76,-1.84,-1.92});
    NDArray dLdwExp('c', {2,3,4}, {-15.6032,-15.3728,-14.9888,-14.4512,-13.76 ,-12.9152,-11.9168,-10.7648, -9.4592, -8. , -6.3872, -4.6208,
                                   -2.7008, -0.6272, 1.6 , 3.9808, 6.5152, 9.2032, 12.0448, 15.04 , 18.1888, 21.4912, 24.9472, 28.5568});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto gradients = lossGrad.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, gradients->status());

    auto* dLdp = gradients->at(0);
    auto* dLdw = gradients->at(1);
    auto* dLdl = gradients->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
    ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));

    delete gradients;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test6) {
    // Weights of shape {1,3,1}, integer arg 2; only the weights gradient is
    // checked.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    weights.assign(0.5);
    labels.linspace(1);

    NDArray dLdwExp('c', {1,3,1}, {-58.16319, -6.5536 , 64.71682});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto gradients = lossGrad.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, gradients->status());

    auto* dLdw = gradients->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete gradients;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) {
    // Scalar weights, integer arg 2; the scalar weights gradient is expected
    // to be exactly 0.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    weights.assign(0.5);
    labels.linspace(1);

    NDArray dLdwExp('c', {}, {0.});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto gradients = lossGrad.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, gradients->status());

    auto* dLdw = gradients->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete gradients;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test8) {
    // Full-shape weights with the first four entries zeroed, integer arg 2.
    // The matching first four predictions-gradient entries are expected to be 0.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    for (int idx = 0; idx < 4; ++idx)
        weights.p(idx, 0.);

    NDArray dLdpExp('c', {2,3,4}, {0. ,0. ,0. ,0. ,-0.48 ,-0.576,-0.672,-0.768,-0.864,-0.96 ,-1.056,-1.152,
                                   -1.248,-1.344,-1.44 ,-1.536,-1.632,-1.728,-1.824,-1.92 ,-2.016,-2.112,-2.208,-2.304});
    NDArray dLdwExp('c', {2,3,4}, {-22.3488 ,-22.07232,-21.61152,-20.9664 ,-20.13696,-19.1232 ,-17.92512,-16.54272,-14.976 ,-13.22496,-11.2896 , -9.16992,
                                   -6.86592, -4.3776 , -1.70496, 1.152 , 4.19328, 7.41888, 10.8288 , 14.42304, 18.2016 , 22.16449, 26.31168, 30.6432 });

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto gradients = lossGrad.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, gradients->status());

    auto* dLdp = gradients->at(0);
    auto* dLdw = gradients->at(1);
    auto* dLdl = gradients->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
    ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));

    delete gradients;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test9) {
    // Full-shape weights, integer arg 3; checks all three gradients.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    weights.assign(0.5);
    labels.linspace(1);

    NDArray dLdpExp('c', {2,3,4}, {-0.04,-0.08,-0.12,-0.16,-0.2 ,-0.24,-0.28,-0.32,-0.36,-0.4 ,-0.44,-0.48,
                                   -0.52,-0.56,-0.6 ,-0.64,-0.68,-0.72,-0.76,-0.8 ,-0.84,-0.88,-0.92,-0.96});
    NDArray dLdwExp('c', {2,3,4}, {0.0384, 0.1536, 0.3456, 0.6144, 0.96 , 1.3824, 1.8816, 2.4576, 3.1104, 3.84 , 4.6464, 5.5296,
                                   6.4896, 7.5264, 8.64 , 9.8304,11.0976,12.4416,13.8624,15.36 ,16.9344,18.5856,20.3136,22.1184});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto gradients = lossGrad.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, gradients->status());

    auto* dLdp = gradients->at(0);
    auto* dLdw = gradients->at(1);
    auto* dLdl = gradients->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
    ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));

    delete gradients;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test10) {
    // Weights of shape {1,1}, integer arg 3; only the weights gradient is
    // checked.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    weights.assign(0.5);
    labels.linspace(1);

    NDArray dLdwExp('c', {1,1}, {188.16});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto gradients = lossGrad.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, gradients->status());

    auto* dLdw = gradients->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete gradients;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test11) {
    // Weights of shape {1,3,1}, integer arg 3; only the weights gradient is
    // checked.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    weights.assign(0.5);
    labels.linspace(1);

    NDArray dLdwExp('c', {1,3,1}, {33.6384 ,59.4432 ,95.07841});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto gradients = lossGrad.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, gradients->status());

    auto* dLdw = gradients->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete gradients;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test12) {
    // Full-shape weights with the first four entries zeroed, integer arg 3.
    // The matching first four predictions-gradient entries are expected to be 0.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    for (int idx = 0; idx < 4; ++idx)
        weights.t<double>(idx) = 0.;

    NDArray dLdpExp('c', {2,3,4}, {0.,0.,0.,0., -0.24 ,-0.288,-0.336,-0.384,-0.432,-0.48 ,-0.528,-0.576,
                                   -0.624,-0.672,-0.72 ,-0.768,-0.816,-0.864,-0.912,-0.96 ,-1.008,-1.056,-1.104,-1.152});
    NDArray dLdwExp('c', {2,3,4}, {0.04608, 0.18432, 0.41472, 0.73728, 1.152 , 1.65888, 2.25792, 2.94912, 3.73248, 4.608 , 5.57568, 6.63552,
                                   7.78752, 9.03168,10.368 ,11.79648,13.31712,14.92992,16.63488,18.432 ,20.32128,22.30272,24.37632,26.54208});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto gradients = lossGrad.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, gradients->status());

    auto* dLdp = gradients->at(0);
    auto* dLdw = gradients->at(1);
    auto* dLdl = gradients->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
    ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));

    delete gradients;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test13) {
    // Weights of shape {2,3,1} with the first three entries zeroed, integer
    // arg 3. Each zeroed weight covers a row of 4 predictions, so the first 12
    // predictions-gradient entries are expected to be 0.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,1}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    for (int idx = 0; idx < 3; ++idx)
        weights.t<double>(idx) = 0.;

    NDArray dLdpExp('c', {2,3,4}, {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                   -1.04,-1.12,-1.2 ,-1.28,-1.36,-1.44,-1.52,-1.6 ,-1.68,-1.76,-1.84,-1.92});
    NDArray dLdwExp('c', {2,3,1}, {2.304 , 13.3632 , 34.2528 , 64.97279,105.5232 ,155.90401});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto gradients = lossGrad.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, gradients->status());

    auto* dLdp = gradients->at(0);
    auto* dLdw = gradients->at(1);
    auto* dLdl = gradients->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
    ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));

    delete gradients;
}
TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test1) {
    // Element-wise (x - y)^2 on two equal-length vectors.
    auto x = NDArrayFactory::create<float>('c', {4}, {0, 1, 2, 3});
    auto y = NDArrayFactory::create<float>('c',{4}, {3, 2, 1, 0});
    auto expected = NDArrayFactory::create<float>('c', {4}, {9, 1,1, 9});

    nd4j::ops::squaredsubtract sqSub;
    auto outputs = sqSub.execute({&x, &y}, {}, {});
    ASSERT_EQ(Status::OK(), outputs->status());
    ASSERT_TRUE(expected.equalsTo(outputs->at(0)));

    delete outputs;
}
TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test2) {
    // (x - y)^2 with y (length 4) broadcast across both rows of x (2x4).
    auto x = NDArrayFactory::create<float>('c', {2, 4}, {0, 1, 2, 3, 0, 1, 2, 3});
    auto y = NDArrayFactory::create<float>('c',{4}, {3, 2, 1, 0});
    auto expected = NDArrayFactory::create<float>('c', {2, 4}, {9, 1,1, 9, 9, 1, 1, 9});

    nd4j::ops::squaredsubtract sqSub;
    auto outputs = sqSub.execute({&x, &y}, {}, {});
    ASSERT_EQ(Status::OK(), outputs->status());
    ASSERT_TRUE(expected.equalsTo(outputs->at(0)));

    delete outputs;
}
TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test3) {
    // Backprop of (x - y)^2 w.r.t. x: the expected values equal
    // 2 * (x - y) * eps element-wise, with y broadcast across rows.
    auto x = NDArrayFactory::create<float>('c', {2, 4}, {0, 1, 2, 3, 0, 1, 2, 3});
    auto y = NDArrayFactory::create<float>('c',{4}, {3, 2, 1, 0});
    auto eps = NDArrayFactory::create<float>('c', {2, 4}, {1,2,3,4,5,6,7,8});
    auto expected = NDArrayFactory::create<float>('c', {2, 4}, {-6, -4, 6, 24, -30, -12, 14, 48});

    nd4j::ops::squaredsubtract_bp sqSubBp;
    auto outputs = sqSubBp.execute({&x, &y, &eps}, {}, {});
    ASSERT_EQ(Status::OK(), outputs->status());
    ASSERT_TRUE(expected.equalsTo(outputs->at(0)));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test1) {
    // Full-shape weights; integer arg 0 (presumably the reduction mode - TODO confirm).
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // Expected gradients w.r.t. predictions and weights.
    NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,
                                   -0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5});
    NDArray dLdwExp('c', {2,3,4}, {0.96, 1.92, 2.88, 3.84, 4.8 , 5.76, 6.72, 7.68, 8.64, 9.6 ,10.56,11.52,
                                   12.48,13.44,14.4 ,15.36,16.32,17.28,18.24,19.2 ,20.16,21.12,22.08,23.04});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(gradP));
    ASSERT_TRUE(dLdpExp.equalsTo(gradP));
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));
    // The labels gradient is checked as the negation of the predictions gradient.
    ASSERT_TRUE(dLdpExp.isSameShape(-*gradL));
    ASSERT_TRUE(dLdpExp.equalsTo(-*gradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test2) {
    // Weights of shape {2,1,4} against {2,3,4} predictions; integer arg 0.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,1,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // Only the weights gradient is verified; it is expected to have the weights' shape.
    NDArray dLdwExp('c', {2,1,4}, {14.4 , 17.28, 20.16, 23.04, 48.96, 51.84, 54.72, 57.6});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradW = outputs->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) {
    // Scalar weights; integer arg 1.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,
                                   -0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5});
    NDArray dLdwExp('c', {}, {288.});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(gradP));
    ASSERT_TRUE(dLdpExp.equalsTo(gradP));
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));
    // The labels gradient is checked as the negation of the predictions gradient.
    ASSERT_TRUE(dLdpExp.isSameShape(-*gradL));
    ASSERT_TRUE(dLdpExp.equalsTo(-*gradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test4) {
    // Weights of shape {1,3,1}; integer arg 1.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray dLdwExp('c', {1,3,1}, {65.28, 96., 126.72001});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    // Only output 1 (the weights gradient) is verified.
    auto *gradW = outputs->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test5) {
    // Full-shape weights; integer arg 2.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray dLdpExp('c', {2,3,4}, {-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,
                                   -0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167});
    NDArray dLdwExp('c', {2,3,4}, {-0.92,-0.84,-0.76,-0.68,-0.6 ,-0.52,-0.44,-0.36,-0.28,-0.2 ,-0.12,-0.04,
                                   0.04, 0.12, 0.2 , 0.28, 0.36, 0.44, 0.52, 0.6 , 0.68, 0.76, 0.84, 0.92});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(gradP));
    ASSERT_TRUE(dLdpExp.equalsTo(gradP));
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));
    // The labels gradient is checked as the negation of the predictions gradient.
    ASSERT_TRUE(dLdpExp.isSameShape(-*gradL));
    ASSERT_TRUE(dLdpExp.equalsTo(-*gradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test6) {
    // Weights of shape {1,3,1}; integer arg 2.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray dLdwExp('c', {1,3,1}, {-2.56, 0., 2.56});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    // Only output 1 (the weights gradient) is verified.
    auto *gradW = outputs->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) {
    // Scalar weights; integer arg 2.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray dLdwExp('c', {}, {0.});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    // Only output 1 (the weights gradient) is verified.
    auto *gradW = outputs->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test8) {
    // Full-shape weights with the first four entries zeroed; integer arg 2.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    for (int idx = 0; idx < 4; ++idx)
        weights.p(idx, 0.);

    NDArray dLdpExp('c', {2,3,4}, {-0. ,-0. ,-0. ,-0. ,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,
                                   -0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05});
    NDArray dLdwExp('c', {2,3,4}, {-1.296,-1.2 ,-1.104,-1.008,-0.912,-0.816,-0.72 ,-0.624,-0.528,-0.432,-0.336,-0.24 ,
                                   -0.144,-0.048, 0.048, 0.144, 0.24 , 0.336, 0.432, 0.528, 0.624, 0.72 , 0.816, 0.912});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(gradP));
    ASSERT_TRUE(dLdpExp.equalsTo(gradP));
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));
    // The labels gradient is checked as the negation of the predictions gradient.
    ASSERT_TRUE(dLdpExp.isSameShape(-*gradL));
    ASSERT_TRUE(dLdpExp.equalsTo(-*gradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test9) {
    // Full-shape weights; integer arg 3.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray dLdpExp('c', {2,3,4}, {-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,
                                   -0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083});
    NDArray dLdwExp('c', {2,3,4}, {0.04, 0.08, 0.12, 0.16, 0.2 , 0.24, 0.28, 0.32,0.36, 0.4 , 0.44, 0.48,
                                   0.52, 0.56, 0.6 , 0.64,0.68, 0.72, 0.76, 0.8 ,0.84, 0.88, 0.92, 0.96});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(gradP));
    ASSERT_TRUE(dLdpExp.equalsTo(gradP));
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));
    // The labels gradient is checked as the negation of the predictions gradient.
    ASSERT_TRUE(dLdpExp.isSameShape(-*gradL));
    ASSERT_TRUE(dLdpExp.equalsTo(-*gradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test10) {
    // Weights of shape {1,1}; integer arg 3.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray dLdwExp('c', {1,1}, {12.});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    // Only output 1 (the weights gradient) is verified.
    auto *gradW = outputs->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test11) {
    // Weights of shape {1,3,1}; integer arg 3.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray dLdwExp('c', {1,3,1}, {2.72, 4., 5.28});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    // Only output 1 (the weights gradient) is verified.
    auto *gradW = outputs->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test12) {
    // Full-shape weights with the first four entries zeroed; integer arg 3.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    for (int idx = 0; idx < 4; ++idx)
        weights.t<double>(idx) = 0.;

    NDArray dLdpExp('c', {2,3,4}, {0., 0., 0., 0., -0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025,
                                   -0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025});
    NDArray dLdwExp('c', {2,3,4}, {0.048, 0.096, 0.144, 0.192,0.24 , 0.288, 0.336, 0.384,0.432, 0.48 , 0.528, 0.576,
                                   0.624, 0.672, 0.72 , 0.768,0.816, 0.864, 0.912, 0.96 ,1.008, 1.056, 1.104, 1.152});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(gradP));
    ASSERT_TRUE(dLdpExp.equalsTo(gradP));
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));
    // The labels gradient is checked as the negation of the predictions gradient.
    ASSERT_TRUE(dLdpExp.isSameShape(-*gradL));
    ASSERT_TRUE(dLdpExp.equalsTo(-*gradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test13) {
    // Weights of shape {2,3,1} with the first three entries zeroed; integer arg 3.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,1}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    for (int idx = 0; idx < 3; ++idx)
        weights.t<double>(idx) = 0.;

    NDArray dLdpExp('c', {2,3,4}, {0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0.,
                                   -0.04167, -0.04167, -0.04167, -0.04167,-0.04167, -0.04167, -0.04167, -0.04167,-0.04167, -0.04167, -0.04167, -0.04167});
    NDArray dLdwExp('c', {2,3,1}, {0.8 ,2.08,3.36,4.64,5.92,7.2 });

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(gradP));
    ASSERT_TRUE(dLdpExp.equalsTo(gradP));
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));
    // The labels gradient is checked as the negation of the predictions gradient.
    ASSERT_TRUE(dLdpExp.isSameShape(-*gradL));
    ASSERT_TRUE(dLdpExp.equalsTo(-*gradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, BFloat16_Test_1) {
    // add on two bfloat16 arrays, compared against a bfloat16 expectation.
    NDArray x = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    NDArray y = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    NDArray exp = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    x.linspace(1);
    y.linspace(1);
    exp.linspace(2,2);

    nd4j::ops::add op;
    auto results = op.execute({&x, &y}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto res = results->at(0);
    ASSERT_TRUE(exp.isSameShape(res));
    ASSERT_TRUE(res->equalsTo(exp));

    delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, BFloat16_Test_2) {
    // Mixed-type add: float16 x plus bfloat16 y, compared against a float16 expectation.
    NDArray x = NDArrayFactory::create<float16>('c', {2,3,4});
    NDArray y = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    NDArray exp = NDArrayFactory::create<float16>('c', {2,3,4});
    x.linspace(1);
    y.linspace(1);
    exp.linspace(2,2);

    nd4j::ops::add op;
    auto results = op.execute({&x, &y}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto res = results->at(0);
    ASSERT_TRUE(exp.isSameShape(res));
    ASSERT_TRUE(res->equalsTo(exp));

    delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, BFloat16_Test_3) {
    // add on BFLOAT16 arrays built through the NDArray(order, shape, dtype) constructor.
    NDArray x('c', {2,3,4}, nd4j::DataType::BFLOAT16);
    NDArray y('c', {2,3,4}, nd4j::DataType::BFLOAT16);
    NDArray exp('c', {2,3,4}, nd4j::DataType::BFLOAT16);
    x.linspace(1);
    y.linspace(1);
    exp.linspace(2,2);

    nd4j::ops::add op;
    auto results = op.execute({&x, &y}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto res = results->at(0);
    ASSERT_TRUE(exp.isSameShape(res));
    ASSERT_TRUE(res->equalsTo(exp));

    delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test1) {
    // Full-shape weights; floating arg 0., integer arg 0.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray dLdpExp('c', {2,3,4}, {-0.25999, -0.755  , -1.25   , -1.745  , -2.24001, -2.73502, -3.23004, -3.72508, -4.22014, -4.71523, -5.21034, -5.70548,
                                   -6.20066, -6.69587, -7.19113, -7.68643, -8.18177, -8.67717, -9.17262, -9.66813,-10.1637 ,-10.65932,-11.15501,-11.65077});
    NDArray dLdwExp('c', {2,3,4}, {0.73395,  0.75335,  0.69315,  0.55335,  0.33395,  0.03495, -0.34366, -0.80186, -1.33967, -1.95708, -2.65411, -3.43074,
                                   -4.28698, -5.22285, -6.23833, -7.33343, -8.50815, -9.76251,-11.0965 ,-12.51013,-14.00341,-15.57633,-17.2289 ,-18.96113});
    NDArray dLdlExp('c', {2,3,4}, {0.04,  0.02,-0.  ,-0.02,-0.04,-0.06,-0.08,-0.1 ,-0.12,-0.14,-0.16,-0.18,
                                   -0.2 ,-0.22,-0.24,-0.26,-0.28,-0.3 ,-0.32,-0.34,-0.36,-0.38,-0.4 ,-0.42});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(gradP));
    ASSERT_TRUE(dLdpExp.equalsTo(gradP));
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));
    ASSERT_TRUE(dLdlExp.isSameShape(gradL));
    ASSERT_TRUE(dLdlExp.equalsTo(gradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test2) {
    // Weights of shape {2,1,4}; floating arg 0.3, integer arg 0.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,1,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53    ,-0.875   ,-1.22    ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048,
                                   -4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577});
    NDArray dLdwExp('c', {2,1,4}, {0.43622, -0.19079, -0.98462, -1.94525,-18.09855,-20.72768,-23.52373,-26.48669});
    NDArray dLdlExp('c', {2,3,4}, {0.028,  0.014, -0.   , -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126,
                                   -0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(gradP));
    ASSERT_TRUE(dLdpExp.equalsTo(gradP));
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));
    ASSERT_TRUE(dLdlExp.isSameShape(gradL));
    ASSERT_TRUE(dLdlExp.equalsTo(gradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) {
    // Scalar weights; floating arg 0.3, integer arg 1.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53    ,-0.875   ,-1.22    ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048,
                                   -4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577});
    NDArray dLdwExp('c', {}, {-91.52109});
    NDArray dLdlExp('c', {2,3,4}, {0.028,  0.014, -0., -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126,
                                   -0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(gradP));
    ASSERT_TRUE(dLdpExp.equalsTo(gradP));
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));
    ASSERT_TRUE(dLdlExp.isSameShape(gradL));
    ASSERT_TRUE(dLdlExp.equalsTo(gradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test4) {
    // Weights of shape {1,3,1}; floating arg 0.3, integer arg 1.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray dLdwExp('c', {1,3,1}, {-12.54779,-28.13393,-50.83936});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    // Only output 1 (the weights gradient) is verified.
    auto *gradW = outputs->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test5) {
    // Full-shape weights; floating arg 0.3, integer arg 2.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray dLdpExp('c', {2,3,4}, {-0.01542,-0.04417,-0.07292,-0.10167,-0.13042,-0.15917,-0.18792,-0.21667,-0.24543,-0.27419,-0.30294,-0.33171,
                                   -0.36047,-0.38924,-0.41801,-0.44679,-0.47556,-0.50435,-0.53314,-0.56193,-0.59072,-0.61953,-0.64833,-0.67715});
    NDArray dLdwExp('c', {2,3,4}, {0.37794, 0.37906, 0.37554, 0.36739, 0.35461, 0.33719, 0.31514, 0.28846, 0.25714, 0.22119, 0.18061, 0.13539,
                                   0.08553, 0.03104,-0.02808,-0.09184,-0.16023,-0.23326,-0.31093,-0.39323,-0.48017,-0.57175,-0.66796,-0.76881});
    NDArray dLdlExp('c', {2,3,4}, {0.00233, 0.00117,-0.,-0.00117,-0.00233,-0.0035 ,-0.00467,-0.00583,-0.007  ,-0.00817,-0.00933,-0.0105,
                                   -0.01167,-0.01283,-0.014  ,-0.01517,-0.01633,-0.0175 ,-0.01867,-0.01983,-0.021  ,-0.02217,-0.02333,-0.0245});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(gradP));
    ASSERT_TRUE(dLdpExp.equalsTo(gradP));
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));
    ASSERT_TRUE(dLdlExp.isSameShape(gradL));
    ASSERT_TRUE(dLdlExp.equalsTo(gradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test6) {
    // Weights of shape {1,3,1}; floating arg 0.3, integer arg 2.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray dLdwExp('c', {1,3,1}, {1.4966 , 0.19776,-1.69436});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    // Only output 1 (the weights gradient) is verified.
    auto *gradW = outputs->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) {
    // Scalar weights; floating arg 0.3, integer arg 2.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray dLdwExp('c', {}, {0.});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    // Only output 1 (the weights gradient) is verified.
    auto *gradW = outputs->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test8) {
    // Full-shape weights with the first four entries zeroed; floating arg 0.3, integer arg 2.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    for (int idx = 0; idx < 4; ++idx)
        weights.p(idx, 0.);

    NDArray dLdpExp('c', {2,3,4}, { 0.     , 0.     , 0.     , 0.     ,-0.1565 ,-0.191  ,-0.2255 ,-0.26001,-0.29451,-0.32902,-0.36353,-0.39805,
                                   -0.43257,-0.46709,-0.50161,-0.53614,-0.57068,-0.60522,-0.63976,-0.67431,-0.70887,-0.74343,-0.778  ,-0.81258});
    NDArray dLdwExp('c', {2,3,4}, {0.54353, 0.54487, 0.54065, 0.53087, 0.51553, 0.49463, 0.46817, 0.43615, 0.39857, 0.35543, 0.30672, 0.25246,
                                   0.19264, 0.12725, 0.0563 ,-0.02021,-0.10228,-0.18992,-0.28312,-0.38188,-0.48621,-0.5961 ,-0.71156,-0.83258});
    NDArray dLdlExp('c', {2,3,4}, {-0.    ,-0.    , 0.    , 0.    ,-0.0028,-0.0042,-0.0056,-0.007 ,-0.0084,-0.0098,-0.0112,-0.0126,
                                   -0.014 ,-0.0154,-0.0168,-0.0182,-0.0196,-0.021 ,-0.0224,-0.0238,-0.0252,-0.0266,-0.028 ,-0.0294});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(gradP));
    ASSERT_TRUE(dLdpExp.equalsTo(gradP));
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));
    ASSERT_TRUE(dLdlExp.isSameShape(gradL));
    ASSERT_TRUE(dLdlExp.equalsTo(gradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test9) {
    // Full-shape weights; floating arg 0.3, integer arg 3.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray dLdpExp('c', {2,3,4}, {-0.00771, -0.02208, -0.03646, -0.05083,-0.06521, -0.07958, -0.09396, -0.10834,-0.12271, -0.13709, -0.15147, -0.16585,
                                   -0.18024, -0.19462, -0.20901, -0.22339,-0.23778, -0.25217, -0.26657, -0.28096,-0.29536, -0.30976, -0.32417, -0.33857});
    NDArray dLdwExp('c', {2,3,4}, {0.03008,  0.03064,  0.02888,  0.02481, 0.01841,  0.00971, -0.00132, -0.01466,-0.03032, -0.0483 , -0.06859, -0.0912 ,
                                   -0.11612, -0.14337, -0.17293, -0.20481,-0.23901, -0.27552, -0.31435, -0.35551,-0.39898, -0.44476, -0.49287, -0.5433 });
    NDArray dLdlExp('c', {2,3,4}, {0.00117,  0.00058, -0.     , -0.00058,-0.00117, -0.00175, -0.00233, -0.00292,-0.0035 , -0.00408, -0.00467, -0.00525,
                                   -0.00583, -0.00642, -0.007  , -0.00758,-0.00817, -0.00875, -0.00933, -0.00992,-0.0105 , -0.01108, -0.01167, -0.01225});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(gradP));
    ASSERT_TRUE(dLdpExp.equalsTo(gradP));
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));
    ASSERT_TRUE(dLdlExp.isSameShape(gradL));
    ASSERT_TRUE(dLdlExp.equalsTo(gradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test10) {
    // Weights of shape {1,1}; floating arg 0.3, integer arg 3.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray dLdwExp('c', {1,1}, {-3.81338});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    // Only output 1 (the weights gradient) is verified.
    auto *gradW = outputs->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test11) {
    // Weights of shape {1,3,1}; floating arg 0.3, integer arg 3.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray dLdwExp('c', {1,3,1}, {-0.52282,-1.17225,-2.11831});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    // Only output 1 (the weights gradient) is verified.
    auto *gradW = outputs->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test12) {
    // Full-shape weights with the first four entries zeroed; floating arg 0.3, integer arg 3.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    for (int idx = 0; idx < 4; ++idx)
        weights.t<double>(idx) = 0.;

    NDArray dLdpExp('c', {2,3,4}, {0.      , 0.      , 0.      , 0.      ,-0.07825, -0.0955 , -0.11275, -0.13    ,-0.14726, -0.16451, -0.18177, -0.19902,
                                   -0.21628, -0.23354, -0.25081, -0.26807,-0.28534, -0.30261, -0.31988, -0.33716,-0.35443, -0.37172, -0.389   , -0.40629});
    NDArray dLdwExp('c', {2,3,4}, {0.0361 ,  0.03677,  0.03466,  0.02977,  0.0221 ,  0.01165, -0.00158, -0.01759,-0.03638, -0.05795, -0.08231, -0.10944,
                                   -0.13935, -0.17204, -0.20752, -0.24577,-0.28681, -0.33063, -0.37723, -0.42661,-0.47877, -0.53372, -0.59144, -0.65196});
    NDArray dLdlExp('c', {2,3,4}, {-0.    , -0.    ,  0.    ,  0.    ,-0.0014, -0.0021, -0.0028, -0.0035,-0.0042, -0.0049, -0.0056, -0.0063,
                                   -0.007 , -0.0077, -0.0084, -0.0091,-0.0098, -0.0105, -0.0112, -0.0119,-0.0126, -0.0133, -0.014 , -0.0147});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(gradP));
    ASSERT_TRUE(dLdpExp.equalsTo(gradP));
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));
    ASSERT_TRUE(dLdlExp.isSameShape(gradL));
    ASSERT_TRUE(dLdlExp.equalsTo(gradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test13) {
    // Weights of shape {2,3,1} with the first three entries zeroed; floating arg 0.3, integer arg 3.
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,1}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    for (int idx = 0; idx < 3; ++idx)
        weights.t<double>(idx) = 0.;

    NDArray dLdpExp('c', {2,3,4}, {0.      , 0.      , 0.      , 0.      , 0.      , 0.      , 0.      , 0.      , 0.      , 0.      , 0.      , 0.      ,
                                   -0.36047, -0.38924, -0.41801, -0.44679,-0.47556, -0.50435, -0.53314, -0.56193,-0.59072, -0.61953, -0.64833, -0.67715});
    NDArray dLdwExp('c', {2,3,1}, {0.22882, 0.02428,-0.4768 ,-1.27447,-2.36878,-3.75981,});
    NDArray dLdlExp('c', {2,3,4}, {-0.     , -0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.,
                                   -0.01167, -0.01283, -0.014  , -0.01517,-0.01633, -0.0175 , -0.01867, -0.01983,-0.021  , -0.02217, -0.02333, -0.0245});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(gradP));
    ASSERT_TRUE(dLdpExp.equalsTo(gradP));
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));
    ASSERT_TRUE(dLdlExp.isSameShape(gradL));
    ASSERT_TRUE(dLdlExp.equalsTo(gradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, BFloat16_Test_4) {
    // Mixed-type add: float x plus bfloat16 y, compared against a float expectation.
    NDArray x = NDArrayFactory::create<float>('c', {2,3,4});
    NDArray y = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    NDArray exp = NDArrayFactory::create<float>('c', {2,3,4});
    x.linspace(1);
    y.linspace(1);
    exp.linspace(2,2);

    nd4j::ops::add op;
    auto results = op.execute({&x, &y}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto res = results->at(0);
    ASSERT_TRUE(exp.isSameShape(res));
    ASSERT_TRUE(res->equalsTo(exp));

    delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, BFloat16_Test_5) {
    // Mixed-type subtract: float x minus bfloat16 y, compared against a float expectation.
    NDArray x = NDArrayFactory::create<float>('c', {2,3,4});
    NDArray y = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    NDArray exp = NDArrayFactory::create<float>('c', {2,3,4});
    x.linspace(2, 2);
    y.linspace(1);
    exp.linspace(1);

    nd4j::ops::subtract op;
    auto results = op.execute({&x, &y}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto res = results->at(0);
    ASSERT_TRUE(exp.isSameShape(res));
    ASSERT_TRUE(res->equalsTo(exp));

    delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, BFloat16_Test_6) {
    // Mixed-precision `subtract`: bfloat16 lhs - double rhs, compared against a bfloat16 expectation.
    NDArray lhs      = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    NDArray rhs      = NDArrayFactory::create<double>('c', {2,3,4});
    NDArray expected = NDArrayFactory::create<bfloat16>('c', {2,3,4});

    lhs.linspace(2, 2);
    rhs.linspace(1);
    expected.linspace(1);     // element-wise lhs - rhs

    nd4j::ops::subtract subOp;
    auto outputs = subOp.execute({&lhs, &rhs}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    ASSERT_TRUE(outputs->at(0)->equalsTo(expected));
    delete outputs;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test1) {
    // Gradients of softmax_cross_entropy_loss_grad for a {2,4} batch of
    // one-hot INT32 labels with one weight per row.
    // Outputs: 0 -> dL/dlogits, 1 -> dL/dweights, 2 -> dL/dlabels.
    NDArray labels('c', {2,4}, {0,0,1,0, 0,1,0,0}, nd4j::DataType::INT32);
    NDArray logits('c', {2,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2}, nd4j::DataType::DOUBLE);
    NDArray dLdpExp('c', {2,4}, {0.1176, 0.1224, -0.3726, 0.1326, 0.1176, -0.3776, 0.1274, 0.1326});
    NDArray dLdwExp('c', {2}, {1.36729, 1.40729});
    logits.linspace(-0.08, 0.04);
    weights.assign(0.5);

    nd4j::ops::softmax_cross_entropy_loss_grad op;
    // NOTE(review): TArg presumably label smoothing, IArg presumably the reduction mode -- confirm against the op declaration.
    auto results = op.execute({&logits, &weights, &labels}, {0.}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    // The labels gradient was previously fetched but never inspected; at least pin its shape.
    ASSERT_TRUE(labels.isSameShape(dLdl));

    delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test2) {
    // Rank-1 logits/labels with a {1}-shaped weight; all logits equal.
    // Outputs: 0 -> dL/dlogits, 1 -> dL/dweights, 2 -> dL/dlabels.
    NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
    NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1}, nd4j::DataType::DOUBLE);
    NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125});
    NDArray dLdwExp('c', {1}, {1.38629});
    logits = 2.;
    weights.assign(0.5);

    nd4j::ops::softmax_cross_entropy_loss_grad op;
    // NOTE(review): IArg presumably the reduction mode -- confirm against the op declaration.
    auto results = op.execute({&logits, &weights, &labels}, {0.}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    // Previously fetched but unused; pin the labels-gradient shape.
    ASSERT_TRUE(labels.isSameShape(dLdl));

    delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test3) {
    // Same as test2 but with a scalar (rank-0) weight.
    // Outputs: 0 -> dL/dlogits, 1 -> dL/dweights, 2 -> dL/dlabels.
    NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
    NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE);
    NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125});
    NDArray dLdwExp('c', {}, {1.38629});
    logits = 2.;
    weights.assign(0.5);

    nd4j::ops::softmax_cross_entropy_loss_grad op;
    auto results = op.execute({&logits, &weights, &labels}, {0.}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    // Previously fetched but unused; pin the labels-gradient shape.
    ASSERT_TRUE(labels.isSameShape(dLdl));

    delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test4) {
    // Scalar weight with IArg 2; the expected weight gradient is 0.
    // Outputs: 0 -> dL/dlogits, 1 -> dL/dweights, 2 -> dL/dlabels.
    NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
    NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE);
    NDArray dLdpExp('c', {4}, {0.23521, 0.2448 , -0.7452 , 0.26519});
    NDArray dLdwExp('c', {}, {0.});
    logits.linspace(-0.08, 0.04);
    weights = 0.5;

    nd4j::ops::softmax_cross_entropy_loss_grad op;
    auto results = op.execute({&logits, &weights, &labels}, {0.}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    // Previously fetched but unused; pin the labels-gradient shape.
    ASSERT_TRUE(labels.isSameShape(dLdl));

    delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test5) {
    // Rank-1 logits with a {1}-shaped weight and IArg 3.
    // Outputs: 0 -> dL/dlogits, 1 -> dL/dweights, 2 -> dL/dlabels.
    NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
    NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1}, nd4j::DataType::DOUBLE);
    NDArray dLdpExp('c', {4}, {0.1176, 0.1224, -0.3726, 0.1326});
    NDArray dLdwExp('c', {1}, {1.36729});
    logits.linspace(-0.08, 0.04);
    weights = 0.5;

    nd4j::ops::softmax_cross_entropy_loss_grad op;
    auto results = op.execute({&logits, &weights, &labels}, {0.}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    // Previously fetched but unused; pin the labels-gradient shape.
    ASSERT_TRUE(labels.isSameShape(dLdl));

    delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test6) {
    // {2,4} batch with per-row weights, non-zero TArg (0.3) and IArg 2.
    // Outputs: 0 -> dL/dlogits, 1 -> dL/dweights, 2 -> dL/dlabels.
    NDArray labels('c', {2,4}, {0,0,1,0, 0,1,0,0}, nd4j::DataType::INT32);
    NDArray logits('c', {2,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2}, nd4j::DataType::DOUBLE);
    NDArray dLdpExp('c', {2,4}, {0.0801, 0.0849, -0.2601, 0.0951, 0.0801, -0.2651, 0.0899, 0.0951});
    NDArray dLdwExp('c', {2}, {-0.014000, 0.014000});
    logits.linspace(-0.08, 0.04);
    weights.assign(0.5);

    nd4j::ops::softmax_cross_entropy_loss_grad op;
    // NOTE(review): TArg 0.3 is presumably the label-smoothing factor -- confirm.
    auto results = op.execute({&logits, &weights, &labels}, {0.3}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    // Previously fetched but unused; pin the labels-gradient shape.
    ASSERT_TRUE(labels.isSameShape(dLdl));

    delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test7) {
    // Rank-3 logits/labels with {1,3} weights (one of them zero) and IArg 3.
    // Outputs: 0 -> dL/dlogits, 1 -> dL/dweights, 2 -> dL/dlabels.
    NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1, 1,0,0,0, 0,1,0,0}, nd4j::DataType::INT32);
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3}, {0.5, 0., 1.5});
    NDArray dLdpExp('c', {2,3,4}, {-0.0956 , 0.0306 , 0.03185, 0.03315, 0.,-0., 0., 0., 0.0882 , 0.0918 ,-0.27945, 0.09945,
                                   0.0294 , 0.0306 , 0.03185,-0.09185,-0., 0., 0., 0., 0.0882 ,-0.2832 , 0.09555, 0.09945});
    NDArray dLdwExp('c', {1,3}, {0.69365, 0.71365, 0.69365});
    logits.linspace(-0.08, 0.04);

    nd4j::ops::softmax_cross_entropy_loss_grad op;
    auto results = op.execute({&logits, &weights, &labels}, {0.}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    // Previously fetched but unused; pin the labels-gradient shape.
    ASSERT_TRUE(labels.isSameShape(dLdl));

    delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test8) {
    // Rank-4 logits/labels with {1,1,4} weights and IArg 2.
    // Outputs: 0 -> dL/dlogits, 1 -> dL/dweights, 2 -> dL/dlabels.
    NDArray labels('c', {2,3,4,5}, {1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,
                                    0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,
                                    0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0}, nd4j::DataType::INT32);
    NDArray logits('c', {2,3,4,5}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,1,4}, nd4j::DataType::DOUBLE);
    NDArray dLdpExp('c', {2,3,4,5}, {-0.03399, 0.00799, 0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335,
                                     0.00866, 0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399,
                                     0.00799, 0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866,
                                     0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799,
                                     0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901,
                                     0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799, 0.00832,
                                     0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901, 0.00768,
                                     0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799, 0.00832, 0.00866,
                                     0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901});
    NDArray dLdwExp('c', {1,1,4}, {0.005, 0.00167, -0.00167, -0.005});
    logits.linspace(-0.08, 0.04);
    weights.assign(0.5);

    nd4j::ops::softmax_cross_entropy_loss_grad op;
    auto results = op.execute({&logits, &weights, &labels}, {0.}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    // Gradients must keep the shapes of the arrays they differentiate.
    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(labels.isSameShape(dLdl));
    // TODO(review): the element-wise comparison against dLdpExp was disabled
    // upstream; re-enable once the expected values are verified:
    // ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, SafeDivideMixed_Test1) {
    // In-place SafeDivide of a floating-point array by an INT64 array of the
    // same shape (mixed data types). The divisor is all ones, so the row sums
    // must come through unchanged. Previously this test had no assertion and
    // only checked that the call did not crash.
    NDArray labels('c', {2, 3}, {1.0, 2.0, 3.0, -1.0, 2.0, 1.0});
    auto sumDiff = labels.reduceAlongDims(reduce::Sum, {1}, true);
    NDArray numOfNonZero(sumDiff.getShapeInfo(), nd4j::DataType::INT64, false);
    numOfNonZero.assign(1);
    sumDiff.applyPairwiseTransform(pairwise::SafeDivide, &numOfNonZero, &sumDiff, nullptr);

    // Row sums: 1+2+3 = 6 and -1+2+1 = 2; built with sumDiff's own dtype so
    // the comparison only exercises values and shape.
    NDArray expected('c', {2, 1}, {6., 2.}, sumDiff.dataType());
    ASSERT_TRUE(expected.isSameShape(sumDiff));
    ASSERT_TRUE(expected.equalsTo(sumDiff));
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test1) {
    // One-hot labels along the last axis of a {2,3,4} tensor; no IArgs supplied.
    NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1, 1,0,0,0, 0,1,0,0});
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);

    NDArray expectedGrad('c', {2,3,4}, {-0.76479, 0.2448, 0.2548, 0.26519, 0.23521,-0.7552, 0.2548, 0.26519, 0.23521, 0.2448,-0.7452, 0.26519,
                                        0.23521, 0.2448, 0.2548,-0.73481,-0.76479, 0.2448, 0.2548, 0.26519, 0.23521,-0.7552, 0.2548, 0.26519});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad gradOp;
    auto outputs = gradOp.execute({&logits, &labels}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expectedGrad.isSameShape(gradLogits));
    ASSERT_TRUE(expectedGrad.equalsTo(gradLogits));
    delete outputs;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test2) {
    // Non-one-hot labels; IArg {1} is presumably the class dimension -- TODO confirm.
    NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,1, 0,0,1,0, 0,0,0,1, 1,0,1,0, 0,1,0,0});
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);

    NDArray expectedGrad('c', {2,3,4}, {-0.71836, 0.28164, 0.28164, 0.28164, 0.33051, -0.66949, 0.33051, -0.66949, 0.38785, 0.38785, -0.61215, 0.38785,
                                        0.28164, 0.28164, 0.28164, -0.71836,-0.66949, 0.33051, -0.66949, 0.33051, 0.38785, -0.61215, 0.38785, 0.38785});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad gradOp;
    auto outputs = gradOp.execute({&logits, &labels}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expectedGrad.isSameShape(gradLogits));
    ASSERT_TRUE(expectedGrad.equalsTo(gradLogits));
    delete outputs;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test3) {
    // Rank-2 input with IArg {0}.
    NDArray labels('c', {2,3}, {1,0,0, 0,1,1});
    NDArray logits('c', {2,3}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);

    NDArray expectedGrad('c', {2,3}, {-0.52996, 0.47004, 0.47004, 0.52996, -0.47004, -0.47004});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad gradOp;
    auto outputs = gradOp.execute({&logits, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expectedGrad.isSameShape(gradLogits));
    ASSERT_TRUE(expectedGrad.equalsTo(gradLogits));
    delete outputs;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test4) {
    // {2,1} input reduced over IArg {1}: a single entry per row, gradient expected to be zero.
    NDArray labels('c', {2,1}, {1,1});
    NDArray logits('c', {2,1}, {-0.04, 0.04});
    NDArray expectedGrad('c', {2,1}, {0., 0.});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad gradOp;
    auto outputs = gradOp.execute({&logits, &labels}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expectedGrad.isSameShape(gradLogits));
    ASSERT_TRUE(expectedGrad.equalsTo(gradLogits));
    delete outputs;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test5) {
    // {2,1} input reduced over IArg {0}.
    NDArray labels('c', {2,1}, {1,0});
    NDArray logits('c', {2,1}, {-0.04, 0.04});
    NDArray expectedGrad('c', {2,1}, {-0.51999, 0.51999});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad gradOp;
    auto outputs = gradOp.execute({&logits, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expectedGrad.isSameShape(gradLogits));
    ASSERT_TRUE(expectedGrad.equalsTo(gradLogits));
    delete outputs;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test6) {
    // {1,2} input reduced over IArg {0}: a single entry per column, gradient expected to be zero.
    NDArray labels('c', {1,2}, {1,1});
    NDArray logits('c', {1,2}, {-0.04, 0.04});
    NDArray expectedGrad('c', {1,2}, {0, 0});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad gradOp;
    auto outputs = gradOp.execute({&logits, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expectedGrad.isSameShape(gradLogits));
    ASSERT_TRUE(expectedGrad.equalsTo(gradLogits));
    delete outputs;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test7) {
    // Rank-1 input of two classes, IArg {0}.
    NDArray labels('c', {2}, {0,1});
    NDArray logits('c', {2}, {-0.04, 0.04});
    NDArray expectedGrad('c', {2}, {0.48001, -0.48001});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad gradOp;
    auto outputs = gradOp.execute({&logits, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expectedGrad.isSameShape(gradLogits));
    ASSERT_TRUE(expectedGrad.equalsTo(gradLogits));
    delete outputs;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test8) {
    // Degenerate single-element input: the gradient is expected to be zero.
    NDArray labels('c', {1}, {1});
    NDArray logits('c', {1}, {0.04});
    NDArray expectedGrad('c', {1}, {0});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad gradOp;
    auto outputs = gradOp.execute({&logits, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expectedGrad.isSameShape(gradLogits));
    ASSERT_TRUE(expectedGrad.equalsTo(gradLogits));
    delete outputs;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, Multiply_BP_Test1) {
    // multiply_bp for z = x * y, where y is a single-element {1,1,1} array
    // broadcast against x.
    // Inputs: x, y, dL/dz.  Outputs: 0 -> dL/dx, 1 -> dL/dy.
    NDArray x('c', {3,4,5}, nd4j::DataType::DOUBLE);
    NDArray y('c', {1,1,1}, nd4j::DataType::DOUBLE);
    NDArray dLdz('c', {3,4,5}, nd4j::DataType::DOUBLE);
    NDArray dLdxExp('c', {3,4,5}, nd4j::DataType::DOUBLE);
    NDArray dLdyExp('c', {1,1,1}, nd4j::DataType::DOUBLE);
    x.assign(1.0);
    y.assign(1.0);
    dLdz.assign(1.0);
    dLdxExp.assign(1.0);    // dL/dx = dL/dz * y = 1 everywhere
    dLdyExp.assign(60.0);   // dL/dy = sum(dL/dz * x) over all 3*4*5 elements

    nd4j::ops::multiply_bp op;
    auto results = op.execute({&x, &y, &dLdz}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdx = results->at(0);
    auto *dLdy = results->at(1);

    ASSERT_TRUE(dLdxExp.isSameShape(dLdx));
    ASSERT_TRUE(dLdxExp.equalsTo(dLdx));
    // Previously unchecked: the broadcast operand's gradient must be reduced back to y's shape.
    ASSERT_TRUE(dLdyExp.isSameShape(dLdy));
    ASSERT_TRUE(dLdyExp.equalsTo(dLdy));

    delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test1) {
    // Integer class indices (one per row) against {2,3} logits; labels are passed first.
    NDArray classIdx('c', {2}, {2,1}, nd4j::DataType::INT64);
    NDArray logits('c', {2,3}, nd4j::DataType::DOUBLE);
    logits.linspace(0.1, 0.1);

    NDArray expectedGrad('c', {2,3}, {0.30061, 0.33222, -0.63283, 0.30061, -0.66778, 0.36717});

    nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad gradOp;
    auto outputs = gradOp.execute({&classIdx, &logits}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expectedGrad.isSameShape(gradLogits));
    ASSERT_TRUE(expectedGrad.equalsTo(gradLogits));
    delete outputs;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test2) {
    // Same shapes as test1, with different class indices and a logits range that starts negative.
    NDArray classIdx('c', {2}, {0,1}, nd4j::DataType::INT64);
    NDArray logits('c', {2,3}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.1, 0.1);

    NDArray expectedGrad('c', {2,3}, {-0.69939, 0.33222, 0.36717, 0.30061, -0.66778, 0.36717});

    nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad gradOp;
    auto outputs = gradOp.execute({&classIdx, &logits}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expectedGrad.isSameShape(gradLogits));
    ASSERT_TRUE(expectedGrad.equalsTo(gradLogits));
    delete outputs;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) {
    // Scalar (rank-0) class index against rank-1 logits.
    NDArray classIdx('c', {}, {1}, nd4j::DataType::INT64);
    NDArray logits('c', {2}, {-0.2, 0.3});
    NDArray expectedGrad('c', {2}, {0.37754, -0.37754});

    nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad gradOp;
    auto outputs = gradOp.execute({&classIdx, &logits}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expectedGrad.isSameShape(gradLogits));
    ASSERT_TRUE(expectedGrad.equalsTo(gradLogits));
    delete outputs;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test4) {
    // {2,3} class indices against {2,3,4} logits (one index per innermost row).
    NDArray classIdx('c', {2,3}, {0,1,1, 3,3,2}, nd4j::DataType::INT64);
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.5, 0.1);

    NDArray expectedGrad('c', {2,3,4}, {-0.78616, 0.23633, 0.26118, 0.28865, 0.21384, -0.76367, 0.26118, 0.28865, 0.21384, -0.76367, 0.26118, 0.28865,
                                        0.21384, 0.23633, 0.26118, -0.71135, 0.21384, 0.23633, 0.26118, -0.71135, 0.21384, 0.23633, -0.73882, 0.28865});

    nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad gradOp;
    auto outputs = gradOp.execute({&classIdx, &logits}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expectedGrad.isSameShape(gradLogits));
    ASSERT_TRUE(expectedGrad.equalsTo(gradLogits));
    delete outputs;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test5) {
    // Minimal batch: a {1,1} class index against {1,1,2} logits.
    NDArray classIdx('c', {1,1}, {0}, nd4j::DataType::INT64);
    NDArray logits('c', {1,1,2}, {-0.3,0.2});
    NDArray expectedGrad('c', {1,1,2}, {-0.62246, 0.62246});

    nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad gradOp;
    auto outputs = gradOp.execute({&classIdx, &logits}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expectedGrad.isSameShape(gradLogits));
    ASSERT_TRUE(expectedGrad.equalsTo(gradLogits));
    delete outputs;
}