// Source: cavis/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp (3121 lines, 130 KiB, C++)

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver on 8/4/2018.
//
#include "testlayers.h"
#include <ops/declarable/CustomOperations.h>
#include <NDArray.h>
#include <ops/ops.h>
#include <GradCheck.h>
using namespace nd4j;
/**
 * Fixture used by every TEST_F in the DeclarableOpsTests11 suite.
 *
 * It carries no state; its constructor only writes a line break to stdout
 * and flushes the stream.
 */
class DeclarableOpsTests11 : public testing::Test {
public:
    DeclarableOpsTests11() {
        // Write a bare newline, then push it out of the stdio buffer right away.
        printf("\n");
        fflush(stdout);
    }
};
// Smoke test for listdiff on two small int vectors: only the returned status is
// verified, the produced arrays are not inspected.
TEST_F(DeclarableOpsTests11, test_listdiff_1) {
    auto lhs = NDArrayFactory::create<int>('c', {4}, {0, 1, 2, 3});
    auto rhs = NDArrayFactory::create<int>('c', {2}, {3, 1});

    nd4j::ops::listdiff op;
    auto outputs = op.execute({&lhs, &rhs}, {}, {});

    ASSERT_EQ(Status::OK(), outputs->status());

    // The result set is handed back as a raw pointer and released here.
    delete outputs;
}
///////////////////////////////////////////////////////////////////
// log_loss_grad with tArg 1e-7 and iArg 0, weights shaped {2,3,4} like the other inputs.
// All three outputs are compared against reference values (presumably dL/dpredictions,
// dL/dweights and dL/dlabels, going by the original variable names).
TEST_F(DeclarableOpsTests11, log_loss_grad_test1) {
    // Inputs: predictions = 0.04, 0.08, ...; labels = linspace(1); weights = 0.5 everywhere.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // Reference values for outputs 0, 1 and 2.
    NDArray wantGradP('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692,
        -24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911});
    NDArray wantGradW('c', {2,3,4}, {3.21887, 4.96807, 6.10512, 6.80726, 7.15461, 7.19051, 6.93973, 6.41584, 5.62456, 4.56548, 3.2326 , 1.61444,
        -0.30659, -2.55529, -5.16569, -8.18417,-11.67468,-15.72734,-20.47379,-26.11644,-32.9902 ,-41.71318,-53.64824,-73.05434});
    NDArray wantGradL('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002,
        -0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903});

    nd4j::ops::log_loss_grad op;
    auto outputs = op.execute({&predictions, &weights, &labels}, {1e-7}, {0}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* gotGradP = outputs->at(0);
    NDArray* gotGradW = outputs->at(1);
    NDArray* gotGradL = outputs->at(2);

    ASSERT_TRUE(wantGradP.isSameShape(gotGradP));
    ASSERT_TRUE(wantGradP.equalsTo(gotGradP));
    ASSERT_TRUE(wantGradW.isSameShape(gotGradW));
    ASSERT_TRUE(wantGradW.equalsTo(gotGradW));
    ASSERT_TRUE(wantGradL.isSameShape(gotGradL));
    ASSERT_TRUE(wantGradL.equalsTo(gotGradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// log_loss_grad with tArg 1e-7 and iArg 0, weights shaped {2,1,4} against {2,3,4} inputs.
// Only output 1 is checked, and it must come back in the weights' own shape.
TEST_F(DeclarableOpsTests11, log_loss_grad_test2) {
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,1,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // Reference values for output 1.
    NDArray wantGradW('c', {2,1,4}, {15.99805, 16.72406, 16.27746, 14.83754,-44.97147,-59.99582,-79.28771,-107.35497});

    nd4j::ops::log_loss_grad op;
    auto outputs = op.execute({&predictions, &weights, &labels}, {1e-7}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* gotGradW = outputs->at(1);
    ASSERT_TRUE(wantGradW.isSameShape(gotGradW));
    ASSERT_TRUE(wantGradW.equalsTo(gotGradW));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// log_loss_grad with tArg 1e-7 and iArg 1, weights built from a data type only
// (no shape given) and set to 0.5. All three outputs are compared against references;
// the reference for output 1 has an empty shape list.
TEST_F(DeclarableOpsTests11, log_loss_grad_test3) {
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // Reference values for outputs 0, 1 and 2.
    NDArray wantGradP('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692,
        -24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911});
    NDArray wantGradW('c', {}, {-227.77286});
    NDArray wantGradL('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002,
        -0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903});

    nd4j::ops::log_loss_grad op;
    auto outputs = op.execute({&predictions, &weights, &labels}, {1e-7}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* gotGradP = outputs->at(0);
    NDArray* gotGradW = outputs->at(1);
    NDArray* gotGradL = outputs->at(2);

    ASSERT_TRUE(wantGradP.isSameShape(gotGradP));
    ASSERT_TRUE(wantGradP.equalsTo(gotGradP));
    ASSERT_TRUE(wantGradW.isSameShape(gotGradW));
    ASSERT_TRUE(wantGradW.equalsTo(gotGradW));
    ASSERT_TRUE(wantGradL.isSameShape(gotGradL));
    ASSERT_TRUE(wantGradL.equalsTo(gotGradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// log_loss_grad with tArg 1e-7 and iArg 1, weights shaped {1,3,1} against {2,3,4} inputs.
// Only output 1 is checked, and it must come back in the weights' own shape.
TEST_F(DeclarableOpsTests11, log_loss_grad_test4) {
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // Reference values for output 1.
    NDArray wantGradW('c', {1,3,1}, {4.8876 , -46.29156, -186.36887});

    nd4j::ops::log_loss_grad op;
    auto outputs = op.execute({&predictions, &weights, &labels}, {1e-7}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* gotGradW = outputs->at(1);
    // Debug aids, disabled by default:
    // gotGradW->printIndexedBuffer();
    // gotGradW->printShapeInfo();
    ASSERT_TRUE(wantGradW.isSameShape(gotGradW));
    ASSERT_TRUE(wantGradW.equalsTo(gotGradW));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// log_loss_grad with tArg 1e-7 and iArg 2, weights shaped {2,3,4} like the other inputs.
// All three outputs are compared against reference values.
TEST_F(DeclarableOpsTests11, log_loss_grad_test5) {
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // Reference values for outputs 0, 1 and 2.
    NDArray wantGradP('c', {2,3,4}, {-1.04166,-1.08696, -1.13636, -1.19048,-1.25 ,-1.31579, -1.38889, -1.47059,-1.5625 ,-1.66667, -1.78571, -1.92308,
        -2.08333,-2.27273, -2.5 , -2.77778,-3.125 ,-3.57143, -4.16667, -5. ,-6.25 ,-8.33333,-12.49999,-24.99993});
    NDArray wantGradW('c', {2,3,4}, {1.05912, 1.20488, 1.29964, 1.35815, 1.3871 , 1.39009, 1.36919, 1.32553, 1.25959, 1.17133, 1.06026, 0.92541,
        0.76533, 0.57794, 0.3604 , 0.10886,-0.18201,-0.51973,-0.91527,-1.38549,-1.95831,-2.68522,-3.67981,-5.29698});
    NDArray wantGradL('c', {2,3,4}, {0.13242, 0.10176, 0.08302, 0.06909, 0.05776, 0.04803, 0.03935, 0.03141, 0.02397, 0.01689, 0.01005, 0.00334,
        -0.00334,-0.01005,-0.01689,-0.02397,-0.03141,-0.03935,-0.04803,-0.05776,-0.06909,-0.08302,-0.10176,-0.13242});

    nd4j::ops::log_loss_grad op;
    auto outputs = op.execute({&predictions, &weights, &labels}, {1e-7}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* gotGradP = outputs->at(0);
    NDArray* gotGradW = outputs->at(1);
    NDArray* gotGradL = outputs->at(2);

    ASSERT_TRUE(wantGradP.isSameShape(gotGradP));
    ASSERT_TRUE(wantGradP.equalsTo(gotGradP));
    ASSERT_TRUE(wantGradW.isSameShape(gotGradW));
    ASSERT_TRUE(wantGradW.equalsTo(gotGradW));
    ASSERT_TRUE(wantGradL.isSameShape(gotGradL));
    ASSERT_TRUE(wantGradL.equalsTo(gotGradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// log_loss_grad with tArg 1e-7 and iArg 2, weights shaped {1,3,1} against {2,3,4} inputs.
// Only output 1 is checked.
TEST_F(DeclarableOpsTests11, log_loss_grad_test6) {
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // Reference values for output 1.
    NDArray wantGradW('c', {1,3,1}, {6.73432, 2.46939,-9.20372});

    nd4j::ops::log_loss_grad op;
    auto outputs = op.execute({&predictions, &weights, &labels}, {1e-7}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* gotGradW = outputs->at(1);
    ASSERT_TRUE(wantGradW.isSameShape(gotGradW));
    ASSERT_TRUE(wantGradW.equalsTo(gotGradW));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// log_loss_grad with tArg 1e-7 and iArg 2, weights built from a data type only
// (no shape given) and set to 0.5. Output 1 is expected to equal 0.
TEST_F(DeclarableOpsTests11, log_loss_grad_test7) {
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // Reference value for output 1.
    NDArray wantGradW('c', {}, {0.});

    nd4j::ops::log_loss_grad op;
    auto outputs = op.execute({&predictions, &weights, &labels}, {1e-7}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* gotGradW = outputs->at(1);
    ASSERT_TRUE(wantGradW.isSameShape(gotGradW));
    ASSERT_TRUE(wantGradW.equalsTo(gotGradW));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// log_loss_grad with tArg 1e-7 and iArg 2, weights shaped {2,3,4} with the first four
// entries zeroed and the rest 0.5. All three outputs are compared against references;
// the references for outputs 0 and 2 are zero at those four positions.
TEST_F(DeclarableOpsTests11, log_loss_grad_test8) {
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // Zero out linear positions 0..3 of the weights.
    for (int pos = 0; pos < 4; ++pos)
        weights.p(pos, 0.);

    // Reference values for outputs 0, 1 and 2.
    NDArray wantGradP('c', {2,3,4}, {0. , 0. , 0. , 0. ,-1.5 ,-1.57895, -1.66667, -1.76471,-1.875 ,-2. , -2.14286, -2.30769,
        -2.5 ,-2.72727, -3. , -3.33333,-3.75 ,-4.28571, -5. , -6. ,-7.49999,-9.99999,-14.99999,-29.99991});
    NDArray wantGradW('c', {2,3,4}, {1.56625, 1.74117, 1.85487, 1.92509, 1.95982, 1.96341, 1.93833, 1.88594, 1.80682, 1.70091, 1.56762, 1.4058 ,
        1.2137 , 0.98883, 0.72779, 0.42594, 0.07689,-0.32837,-0.80302,-1.36728,-2.05466,-2.92696,-4.12046,-6.06107});
    NDArray wantGradL('c', {2,3,4}, {0. , 0. , 0. , 0. , 0.06931, 0.05763, 0.04722, 0.03769, 0.02877, 0.02027, 0.01206, 0.004,
        -0.004 ,-0.01206,-0.02027,-0.02877,-0.03769,-0.04722,-0.05763,-0.06931,-0.08291,-0.09962,-0.12212,-0.1589});

    nd4j::ops::log_loss_grad op;
    auto outputs = op.execute({&predictions, &weights, &labels}, {1e-7}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* gotGradP = outputs->at(0);
    NDArray* gotGradW = outputs->at(1);
    NDArray* gotGradL = outputs->at(2);

    ASSERT_TRUE(wantGradP.isSameShape(gotGradP));
    ASSERT_TRUE(wantGradP.equalsTo(gotGradP));
    ASSERT_TRUE(wantGradW.isSameShape(gotGradW));
    ASSERT_TRUE(wantGradW.equalsTo(gotGradW));
    ASSERT_TRUE(wantGradL.isSameShape(gotGradL));
    ASSERT_TRUE(wantGradL.equalsTo(gotGradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// log_loss_grad with tArg 1e-7 and iArg 3, weights shaped {2,3,4} like the other inputs.
// All three outputs are compared against reference values.
TEST_F(DeclarableOpsTests11, log_loss_grad_test9) {
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // Reference values for outputs 0, 1 and 2.
    NDArray wantGradP('c', {2,3,4}, {-0.52083,-0.54348,-0.56818, -0.59524,-0.625 ,-0.65789,-0.69444, -0.73529,-0.78125,-0.83333,-0.89286, -0.96154,
        -1.04167,-1.13636,-1.25 , -1.38889,-1.5625 ,-1.78571,-2.08333, -2.5 ,-3.125 ,-4.16666,-6.24999,-12.49996});
    NDArray wantGradW('c', {2,3,4}, {0.13412, 0.207 , 0.25438, 0.28364, 0.29811, 0.2996 , 0.28916, 0.26733, 0.23436, 0.19023, 0.13469, 0.06727,
        -0.01277,-0.10647,-0.21524,-0.34101,-0.48645,-0.65531,-0.85307,-1.08819,-1.37459,-1.73805,-2.23534,-3.04393});
    NDArray wantGradL('c', {2,3,4}, {0.06621, 0.05088, 0.04151, 0.03455, 0.02888, 0.02401, 0.01968, 0.0157 , 0.01199, 0.00845, 0.00502, 0.00167,
        -0.00167,-0.00502,-0.00845,-0.01199,-0.0157 ,-0.01968,-0.02401,-0.02888,-0.03455,-0.04151,-0.05088,-0.06621});

    nd4j::ops::log_loss_grad op;
    auto outputs = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* gotGradP = outputs->at(0);
    NDArray* gotGradW = outputs->at(1);
    NDArray* gotGradL = outputs->at(2);

    ASSERT_TRUE(wantGradP.isSameShape(gotGradP));
    ASSERT_TRUE(wantGradP.equalsTo(gotGradP));
    ASSERT_TRUE(wantGradW.isSameShape(gotGradW));
    ASSERT_TRUE(wantGradW.equalsTo(gotGradW));
    ASSERT_TRUE(wantGradL.isSameShape(gotGradL));
    ASSERT_TRUE(wantGradL.equalsTo(gotGradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// log_loss_grad with tArg 1e-7 and iArg 3, weights shaped {1,1} against {2,3,4} inputs.
// Only output 1 is checked, and it must come back shaped {1,1}.
TEST_F(DeclarableOpsTests11, log_loss_grad_test10) {
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // Reference value for output 1.
    NDArray wantGradW('c', {1,1}, {-9.49054});

    nd4j::ops::log_loss_grad op;
    auto outputs = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* gotGradW = outputs->at(1);
    ASSERT_TRUE(wantGradW.isSameShape(gotGradW));
    ASSERT_TRUE(wantGradW.equalsTo(gotGradW));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// log_loss_grad with tArg 1e-7 and iArg 3, weights shaped {1,3,1} against {2,3,4} inputs.
// Only output 1 is checked.
TEST_F(DeclarableOpsTests11, log_loss_grad_test11) {
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // Reference values for output 1.
    NDArray wantGradW('c', {1,3,1}, {0.20365,-1.92882,-7.76537});

    nd4j::ops::log_loss_grad op;
    auto outputs = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* gotGradW = outputs->at(1);
    ASSERT_TRUE(wantGradW.isSameShape(gotGradW));
    ASSERT_TRUE(wantGradW.equalsTo(gotGradW));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// log_loss_grad with tArg 1e-7 and iArg 3, weights shaped {2,3,4} with the first four
// entries zeroed and the rest 0.5. All three outputs are compared against references;
// the references for outputs 0 and 2 are zero at those four positions.
TEST_F(DeclarableOpsTests11, log_loss_grad_test12) {
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // Zero out linear positions 0..3 of the weights through the typed accessor.
    for (int pos = 0; pos < 4; ++pos)
        weights.t<double>(pos) = 0.;

    // Reference values for outputs 0, 1 and 2.
    NDArray wantGradP('c', {2,3,4}, { 0. , 0. , 0. , 0. ,-0.75 ,-0.789473,-0.833333, -0.882353,-0.9375 ,-1. ,-1.071428, -1.153846,
        -1.25 ,-1.363636,-1.5 , -1.666666,-1.875 ,-2.142857,-2.499999, -2.999999,-3.749997,-4.999997,-7.499993,-14.999956});
    NDArray wantGradW('c', {2,3,4}, {0.16094, 0.2484 , 0.30526, 0.34036, 0.35773, 0.35953, 0.34699, 0.32079, 0.28123, 0.22827, 0.16163, 0.08072,
        -0.01533,-0.12776,-0.25828,-0.40921,-0.58373,-0.78637,-1.02369,-1.30582,-1.64951,-2.08566,-2.68241,-3.65272});
    NDArray wantGradL('c', {2,3,4}, {0. , 0. , 0. , 0. , 0.03466, 0.02882, 0.02361, 0.01884, 0.01438, 0.01014, 0.00603, 0.002 ,
        -0.002 ,-0.00603,-0.01014,-0.01438,-0.01884,-0.02361,-0.02882,-0.03466,-0.04146,-0.04981,-0.06106,-0.07945});

    nd4j::ops::log_loss_grad op;
    auto outputs = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* gotGradP = outputs->at(0);
    NDArray* gotGradW = outputs->at(1);
    NDArray* gotGradL = outputs->at(2);

    ASSERT_TRUE(wantGradP.isSameShape(gotGradP));
    ASSERT_TRUE(wantGradP.equalsTo(gotGradP));
    ASSERT_TRUE(wantGradW.isSameShape(gotGradW));
    ASSERT_TRUE(wantGradW.equalsTo(gotGradW));
    ASSERT_TRUE(wantGradL.isSameShape(gotGradL));
    ASSERT_TRUE(wantGradL.equalsTo(gotGradL));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// log_loss_grad with tArg 1e-7 and iArg 3, weights shaped {2,3,1} with the first three
// entries zeroed and the rest 0.5. All three outputs are compared against references;
// the references for outputs 0 and 2 are zero over their first twelve entries.
TEST_F(DeclarableOpsTests11, log_loss_grad_test13) {
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,1}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // Zero out linear positions 0..2 of the weights through the typed accessor.
    for (int pos = 0; pos < 3; ++pos)
        weights.t<double>(pos) = 0.;

    // Reference values for outputs 0, 1 and 2.
    NDArray wantGradP('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
        -2.08333,-2.27273, -2.5 , -2.77778,-3.125 ,-3.57143, -4.16667, -5. ,-6.25 ,-8.33333,-12.49999,-24.99993});
    NDArray wantGradW('c', {2,3,1}, {1.75828, 2.30839, 1.25309, -1.35098, -6.16602,-16.78383});
    NDArray wantGradL('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
        -0.00334,-0.01005,-0.01689,-0.02397,-0.03141,-0.03935,-0.04803,-0.05776,-0.06909,-0.08302,-0.10176,-0.13242});

    nd4j::ops::log_loss_grad op;
    auto outputs = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* gotGradP = outputs->at(0);
    NDArray* gotGradW = outputs->at(1);
    NDArray* gotGradL = outputs->at(2);

    ASSERT_TRUE(wantGradP.isSameShape(gotGradP));
    ASSERT_TRUE(wantGradP.equalsTo(gotGradP));
    ASSERT_TRUE(wantGradW.isSameShape(gotGradW));
    ASSERT_TRUE(wantGradW.equalsTo(gotGradW));
    ASSERT_TRUE(wantGradL.isSameShape(gotGradL));
    ASSERT_TRUE(wantGradL.equalsTo(gotGradL));

    delete outputs;
}
// resize_bicubic on a {1,7,7,1} double image with a requested size of {30, 30}:
// the first output must match the {1,30,30,1} reference table in shape and values.
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test1) {
    // Source image, given explicitly.
    auto image = NDArrayFactory::create<double>('c', {1, 7, 7, 1}, {
        1, 2.1, 3.15, 4.2, 5.15, 6.1, 7,
        8, 9.1, 10., 11, 12.9, 13.1, 14,
        15, 16., 17., 18, 19, 20., 21,
        22, 23., 24., 25, 26, 27, 28,
        30, 31, 32, 33, 34., 35, 36,
        37, 38, 39, 40, 41., 42, 43,
        44, 45, 46, 47, 48., 49, 50
    });

    // Reference output (900 values).
    auto reference = NDArrayFactory::create<double>('c', {1, 30, 30, 1}, {
        1. ,1.1976162 ,1.4174359 ,1.6775769 ,1.9961575 ,2.3283265 ,
        2.550918 ,2.7360606 ,2.9655411 ,3.2929654 ,3.5441515 ,3.7380352 ,
        3.948995 ,4.248106 ,4.5073795 ,4.6843743 ,4.8572845 ,5.104302 ,
        5.3869915 ,5.581401 ,5.7539616 ,5.974285 ,6.272836 ,6.5204263 ,
        6.718899 ,6.8871036 ,7.039068 ,7.099216 ,7.0784245 ,7.0281887 ,
        2.247592 ,2.446947 ,2.6694887 ,2.9312382 ,3.248216 ,3.5745337 ,
        3.78931 ,3.9656973 ,4.186417 ,4.5046535 ,4.740569 ,4.9217057 ,
        5.133866 ,5.459533 ,5.7744613 ,6.0197873 ,6.254011 ,6.535633 ,
        6.8097296 ,6.9607787 ,7.0749416 ,7.241601 ,7.5094895 ,7.7499495 ,
        7.954571 ,8.131972 ,8.286526 ,8.346463 ,8.325745 ,8.275683 ,
        3.6286845 ,3.830573 ,4.0569587 ,4.3211575 ,4.6364856 ,4.9556503 ,
        5.160583 ,5.3258467 ,5.535462 ,5.84216 ,6.058749 ,6.223753 ,
        6.437597 ,6.797369 ,7.1836042 ,7.5164022 ,7.8290343 ,8.154773 ,
        8.417635 ,8.512958 ,8.5521 ,8.649708 ,8.87788 ,9.108794 ,
        9.320926 ,9.509781 ,9.667375 ,9.72694 ,9.706349 ,9.656599 ,
        5.276778 ,5.480438 ,5.709702 ,5.9754477 ,6.288551 ,6.6005697 ,
        6.796207 ,6.9511423 ,7.1503997 ,7.4461427 ,7.644651 ,7.794562 ,
        8.009684 ,8.400473 ,8.851847 ,9.26469 ,9.649218, 10.015648 ,
        10.268647 ,10.313368 ,10.2843275 ,10.319379 ,10.512033 ,10.734956 ,
        10.954604 ,11.154507 ,11.315369 ,11.374779 ,11.354242 ,11.304622 ,
        7.325373 ,7.5284843 ,7.757575 ,8.022221 ,8.331997 ,8.638187 ,
        8.827649 ,8.976217 ,9.168955 ,9.45726 ,9.6442375 ,9.784517 ,
        9.999621, 10.407702 ,10.896234, 11.355122, 11.781423, 12.172186 ,
        12.420712 ,12.4374485 ,12.370511 ,12.371386 ,12.545973 ,12.766424 ,
        12.992249 ,13.20012 ,13.364252 ,13.424109 ,13.40342 ,13.353425 ,
        9.493208 ,9.692467 ,9.9169445, 10.176801, 10.482199, 10.78547 ,
        10.974367 ,11.123442 ,11.31637 ,11.603645 ,11.790616 ,11.930889 ,
        12.144082 ,12.546447 ,13.024898 ,13.4723 ,13.889232 ,14.276275 ,
        14.528972 ,14.555555 ,14.50145 ,14.515459 ,14.700572 ,14.927055 ,
        15.156046 ,15.366046 ,15.532901 ,15.594008 ,15.5728855 ,15.521847 ,
        10.970133 ,11.163599 ,11.380694 ,11.633735 ,11.935032 ,12.238887 ,
        12.43254 ,12.588294 ,12.787534 ,13.079956 ,13.27752 ,13.426631 ,
        13.636713 ,14.013844 ,14.441672 ,14.827978 ,15.191209 ,15.549808 ,
        15.81343 ,15.881828 ,15.883522 ,15.950411 ,16.16933 ,16.40794 ,
        16.636436 ,16.842583 ,17.010887 ,17.07363 ,17.05194 ,16.999537 ,
        12.219155 ,12.406129 ,12.614796 ,12.860335 ,13.157928 ,13.464224 ,
        13.665207 ,13.830567 ,14.039036 ,14.339629 ,14.552863 ,14.715049 ,
        14.921564 ,15.264454 ,15.622843 ,15.924977 ,16.213829 ,16.532364 ,
        16.8099 ,16.934835 ,17.012146 ,17.150164 ,17.413412 ,17.666712 ,
        17.892765 ,18.09207 ,18.261044 ,18.325531 ,18.303238 ,18.249378 ,
        13.7663965 ,13.947391 ,14.148263 ,14.386917 ,14.681246 ,14.990087 ,
        15.198166 ,15.372728 ,15.590062 ,15.898583 ,16.126892 ,16.301655 ,
        16.50487 ,16.815214 ,17.107498 ,17.329458 ,17.547403 ,17.827654 ,
        18.118288 ,18.296928 ,18.4461 ,18.651634 ,18.956806 ,19.22382 ,
        19.447308 ,19.639887 ,19.809319 ,19.875397 ,19.852556 ,19.797365 ,
        15.9419365 ,16.118704 ,16.314133 ,16.547867 ,16.839561 ,17.14954 ,
        17.361883 ,17.542162 ,17.764957 ,18.078188 ,18.315733 ,18.498205 ,
        18.699116 ,18.988684 ,19.238989 ,19.410137 ,19.583265 ,19.839512 ,
        20.13878 ,20.35177 ,20.546844 ,20.795671 ,21.128067 ,21.404358 ,
        21.626736 ,21.8155 ,21.98561 ,22.052843 ,22.029604 ,21.973448 ,
        17.53522 ,17.71077 ,17.904636 ,18.13695 ,18.42784 ,18.738056 ,
        18.951529 ,19.133352 ,19.357613 ,19.672083 ,19.912102 ,20.096638 ,
        20.296894 ,20.580765 ,20.819603 ,20.976887 ,21.137802 ,21.387535 ,
        21.689209 ,21.911621 ,22.119276 ,22.37999 ,22.71991 ,22.998823 ,
        23.22097 ,23.40876 ,23.57911 ,23.646685 ,23.623325 ,23.566887 ,
        18.746353 ,18.922657 ,19.117487 ,19.350685 ,19.64207 ,19.952137 ,
        20.164913 ,20.345781 ,20.569134 ,20.88284 ,21.12133 ,21.30459 ,
        21.505253 ,21.792645 ,22.038572 ,22.204426 ,22.37289 ,22.626648 ,
        22.926834 ,23.143423 ,23.343302 ,23.596668 ,23.931936 ,24.209232 ,
        24.431519 ,24.619913 ,24.79011 ,24.857473 ,24.83419 ,24.777927 ,
        20.16656 ,20.344206 ,20.540766 ,20.775532 ,21.067804 ,21.377607 ,
        21.589132 ,21.768297 ,21.99003 ,22.302366 ,22.538124 ,22.719105 ,
        22.920494 ,23.214176 ,23.472767 ,23.653934 ,23.83589 ,24.096842 ,
        24.394371 ,24.600555 ,24.786541 ,25.026773 ,25.353731 ,25.62813 ,
        25.850672 ,26.04014 ,26.210072 ,26.277063 ,26.253906 ,26.197956 ,
        22.363024 ,22.54125 ,22.738552 ,22.973991 ,23.266647 ,23.57634 ,
        23.787327 ,23.96576 ,24.186796 ,24.498543 ,24.733124 ,24.913122 ,
        25.114826 ,25.411213 ,25.675262 ,25.863028 ,26.050789 ,26.314838 ,
        26.611223 ,26.812925 ,26.992926 ,27.227505 ,27.550882 ,27.824034 ,
        28.046684 ,28.236614 ,28.406433 ,28.473265 ,28.450163 ,28.394344 ,
        24.429443 ,24.60767 ,24.80497 ,25.04041 ,25.333065 ,25.642756 ,
        25.853743 ,26.032173 ,26.25321 ,26.564959 ,26.79954 ,26.97954 ,
        27.181242 ,27.47763 ,27.74168 ,27.929441 ,28.117207 ,28.381254 ,
        28.677637 ,28.879343 ,29.059345 ,29.293922 ,29.617298 ,29.890451 ,
        30.113104 ,30.303034 ,30.472853 ,30.539684 ,30.516582 ,30.460762 ,
        26. ,26.178228 ,26.375526 ,26.61097 ,26.903624 ,27.213314 ,
        27.424305 ,27.602734 ,27.823772 ,28.135519 ,28.3701 ,28.550098 ,
        28.7518 ,29.04819 ,29.312237 ,29.5 ,29.687763 ,29.951813 ,
        30.2482 ,30.449903 ,30.629902 ,30.864483 ,31.187859 ,31.461012 ,
        31.683659 ,31.873592 ,32.043407 ,32.11024 ,32.087135 ,32.03132 ,
        27.570559 ,27.748787 ,27.946087 ,28.181528 ,28.474184 ,28.783876 ,
        28.994865 ,29.173294 ,29.39433 ,29.70608 ,29.940659 ,30.120655 ,
        30.32236 ,30.618746 ,30.882797 ,31.070557 ,31.25832 ,31.522371 ,
        31.818754 ,32.02046 ,32.20046 ,32.43504 ,32.758415 ,33.031567 ,
        33.25422 ,33.44415 ,33.613964 ,33.680794 ,33.657696 ,33.60188 ,
        29.636976 ,29.815207 ,30.0125 ,30.247944 ,30.5406 ,30.85029 ,
        31.061283 ,31.239712 ,31.46075 ,31.7725 ,32.00708 ,32.187077 ,
        32.38878 ,32.685165 ,32.949215 ,33.13698 ,33.32474 ,33.58879 ,
        33.885178 ,34.086884 ,34.26688 ,34.501457 ,34.824837 ,35.09799 ,
        35.320637 ,35.510574 ,35.68039 ,35.747215 ,35.724117 ,35.6683 ,
        31.83344 ,32.011665 ,32.20897 ,32.444412 ,32.73707 ,33.046757 ,
        33.257744 ,33.436176 ,33.657207 ,33.96896 ,34.203537 ,34.383537 ,
        34.58524 ,34.88163 ,35.145676 ,35.33344 ,35.521206 ,35.785255 ,
        36.081642 ,36.28334 ,36.46334 ,36.69792 ,37.021297 ,37.294453 ,
        37.517097 ,37.707027 ,37.876846 ,37.94368 ,37.920578 ,37.864758 ,
        33.253647 ,33.431873 ,33.62917 ,33.864613 ,34.15727 ,34.466957 ,
        34.677948 ,34.856377 ,35.077415 ,35.38916 ,35.623745 ,35.803745 ,
        36.005447 ,36.301834 ,36.565884 ,36.753647 ,36.941406 ,37.205456 ,
        37.50184 ,37.703545 ,37.883545 ,38.118122 ,38.4415 ,38.714653 ,
        38.9373 ,39.127235 ,39.297054 ,39.363884 ,39.340782 ,39.28496 ,
        34.464783 ,34.64301 ,34.840305 ,35.075752 ,35.368404 ,35.6781 ,
        35.889088 ,36.067516 ,36.28855 ,36.6003 ,36.834885 ,37.014877 ,
        37.216583 ,37.51297 ,37.77702 ,37.964783 ,38.152546 ,38.416595 ,
        38.71298 ,38.914684 ,39.094685 ,39.32926 ,39.652645 ,39.925793 ,
        40.14844 ,40.338375 ,40.508194 ,40.575024 ,40.55192 ,40.496105 ,
        36.058067 ,36.23629 ,36.43359 ,36.669033 ,36.961685 ,37.271378 ,
        37.48237 ,37.6608 ,37.881836 ,38.19359 ,38.42817 ,38.608162 ,
        38.809868 ,39.10625 ,39.3703 ,39.558064 ,39.74583 ,40.00988 ,
        40.306267 ,40.50797 ,40.68797 ,40.92255 ,41.245926 ,41.519077 ,
        41.741722 ,41.931652 ,42.101475 ,42.168304 ,42.145203 ,42.089386 ,
        38.315002 ,38.493233 ,38.690533 ,38.925976 ,39.218628 ,39.52832 ,
        39.739307 ,39.917736 ,40.138775 ,40.45052 ,40.685104 ,40.865097 ,
        41.066803 ,41.36319 ,41.627243 ,41.815002 ,42.002766 ,42.26682 ,
        42.5632 ,42.764908 ,42.944904 ,43.179485 ,43.50286 ,43.776016 ,
        43.998665 ,44.188595 ,44.358418 ,44.425247 ,44.402145 ,44.34633 ,
        40.22708 ,40.40531 ,40.602608 ,40.83805 ,41.130707 ,41.440395 ,
        41.651382 ,41.82982 ,42.050854 ,42.3626 ,42.597183 ,42.77718 ,
        42.97888 ,43.27527 ,43.53932 ,43.72708 ,43.914845 ,44.178894 ,
        44.47528 ,44.676983 ,44.856983 ,45.09156 ,45.41494 ,45.68809 ,
        45.91074 ,46.100674 ,46.270493 ,46.337322 ,46.31422 ,46.2584 ,
        41.785618 ,41.963844 ,42.161144 ,42.396584 ,42.68924 ,42.998936 ,
        43.209923 ,43.388355 ,43.609394 ,43.921143 ,44.15572 ,44.335716 ,
        44.53742 ,44.833805 ,45.09786 ,45.285614 ,45.473377 ,45.737427 ,
        46.033817 ,46.235523 ,46.415524 ,46.650105 ,46.973476 ,47.24663 ,
        47.469276 ,47.65921 ,47.82903 ,47.895855 ,47.872753 ,47.81694 ,
        43.11514 ,43.293365 ,43.490665 ,43.726105 ,44.018764 ,44.328457 ,
        44.539444 ,44.717873 ,44.93891 ,45.25066 ,45.48524 ,45.665237 ,
        45.86694 ,46.163326 ,46.427376 ,46.615143 ,46.802902 ,47.066956 ,
        47.363342 ,47.56505 ,47.74505 ,47.979626 ,48.302998 ,48.576153 ,
        48.798798 ,48.98873 ,49.158546 ,49.225376 ,49.202282 ,49.146458 ,
        44.303867 ,44.482094 ,44.679394 ,44.914833 ,45.207493 ,45.51718 ,
        45.72817 ,45.9066 ,46.12764 ,46.439384 ,46.673965 ,46.853966 ,
        47.055668 ,47.352055 ,47.6161 ,47.803867 ,47.99163 ,48.25568 ,
        48.552063 ,48.75377 ,48.933773 ,49.16835 ,49.491726 ,49.764877 ,
        49.987526 ,50.17746 ,50.347275 ,50.4141 ,50.391006 ,50.335186 ,
        44.771675 ,44.949905 ,45.1472 ,45.382645 ,45.6753 ,45.98499 ,
        46.195976 ,46.374413 ,46.595448 ,46.907196 ,47.141773 ,47.321774 ,
        47.523476 ,47.819862 ,48.08391 ,48.27168 ,48.459446 ,48.72349 ,
        49.019882 ,49.22158 ,49.401585 ,49.63616 ,49.959538 ,50.232693 ,
        50.455338 ,50.64527 ,50.81509 ,50.88192 ,50.858818 ,50.803 ,
        44.609966 ,44.788193 ,44.985493 ,45.220936 ,45.51359 ,45.82328 ,
        46.03427 ,46.2127 ,46.433743 ,46.74549 ,46.98007 ,47.160065 ,
        47.36177 ,47.658157 ,47.922207 ,48.10997 ,48.297733 ,48.561783 ,
        48.858166 ,49.059875 ,49.239872 ,49.47445 ,49.79783 ,50.07098 ,
        50.293625 ,50.48356 ,50.653378 ,50.720203 ,50.6971 ,50.64128 ,
        44.219246 ,44.397472 ,44.594772 ,44.83021 ,45.122868 ,45.43256 ,
        45.643543 ,45.82198 ,46.04302 ,46.354763 ,46.589344 ,46.76934 ,
        46.971046 ,47.267433 ,47.531483 ,47.719242 ,47.907005 ,48.17105 ,
        48.467438 ,48.66914 ,48.849144 ,49.08372 ,49.4071 ,49.680256 ,
        49.902905 ,50.092834 ,50.262653 ,50.329483 ,50.30638 ,50.25057});

    // Requested output size, passed to the op as a second (int) input array.
    auto targetSize = NDArrayFactory::create<int>({30, 30});

    nd4j::ops::resize_bicubic op;
    auto outputs = op.execute({&image, &targetSize}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* resized = outputs->at(0);
    // Debug aids, disabled by default:
    // resized->printBuffer("Resized to 30x30");
    // reference.printBuffer("Expect for 30x30");
    ASSERT_TRUE(reference.isSameShape(resized));
    ASSERT_TRUE(reference.equalsTo(resized));

    delete outputs;
}
// resize_bicubic on a {2,5,4,3} double input filled with linspace(1) and a requested
// size of {10, 8}: the first output must match the {2,10,8,3} reference table.
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test2) {
    // Source data is generated rather than listed: linspace(1) over the whole buffer.
    auto image = NDArrayFactory::create<double>('c', {2, 5, 4, 3});
    image.linspace(1);

    // Reference output (480 values).
    auto reference = NDArrayFactory::create<double>('c', {2, 10, 8, 3}, {
        1. , 2. ,3. ,2.21875, 3.21875, 4.21875, 4. , 5. , 6. ,5.5,
        6.5, 7.5, 7., 8., 9. ,8.78125, 9.78125, 10.78125, 10., 11. ,
        12., 10.28125 , 11.28125 ,12.28125, 5.875, 6.875, 7.875, 7.09375, 8.09375 ,9.09375,
        8.875, 9.875, 10.875, 10.375, 11.375, 12.375 ,11.875 ,12.875 , 13.875, 13.65625,
        14.65625, 15.65625, 14.875 ,15.875 ,16.875 , 15.15625, 16.15625, 17.15625, 13., 14.,
        15. ,14.21875, 15.21875, 16.21875, 16., 17., 18. ,17.5 ,18.5 , 19.5,
        19., 20., 21., 20.78125 ,21.78125 ,22.78125, 22., 23. , 24. , 22.28125,
        23.28125 ,24.28125 ,19. , 20., 21., 20.21875, 21.21875, 22.21875 ,22. ,23.,
        24. , 23.5, 24.5, 25.5, 25. ,26. ,27., 26.78125 , 27.78125, 28.78125,
        28., 29. ,30. ,28.28125, 29.28125, 30.28125, 25., 26., 27. ,26.21875,
        27.21875, 28.21875, 28., 29., 30., 29.5 ,30.5 ,31.5 , 31., 32.,
        33., 32.78125, 33.78125 ,34.78125 ,34., 35., 36., 34.28125, 35.28125, 36.28125,
        31. ,32., 33. , 32.21875, 33.21875, 34.21875, 34. ,35. ,36., 35.5,
        36.5 , 37.5 , 37., 38. ,39. ,38.78125, 39.78125, 40.78125, 40., 41.,
        42. ,40.28125 ,41.28125, 42.28125, 37. , 38., 39., 38.21875 ,39.21875 ,40.21875,
        40. , 41. , 42. , 41.5, 42.5, 43.5 ,43., 44., 45., 44.78125,
        45.78125, 46.78125 ,46. ,47. , 48. , 46.28125 , 47.28125, 48.28125, 44.125 ,45.125,
        46.125, 45.34375, 46.34375, 47.34375, 47.125, 48.125 ,49.125 ,48.625, 49.625 , 50.625,
        50.125 , 51.125, 52.125 ,51.90625 ,52.90625, 53.90625, 53.125, 54.125, 55.125, 53.40625,
        54.40625 ,55.40625, 49. ,50. , 51. ,50.21875, 51.21875 ,52.21875 ,52. ,53.,
        54. ,53.5 , 54.5, 55.5 ,55. ,56. ,57. ,56.78125 ,57.78125, 58.78125,
        58. ,59. ,60. ,58.28125 ,59.28125 ,60.28125, 50.125, 51.125 ,52.125 ,51.34375,
        52.34375 ,53.34375 ,53.125, 54.125, 55.125 ,54.625 ,55.625 ,56.625 ,56.125 ,57.125,
        58.125, 57.90625 ,58.90625 ,59.90625 ,59.125 ,60.125 ,61.125, 59.40625, 60.40625 ,61.40625,
        61. ,62. ,63. ,62.21875, 63.21875, 64.21875 ,64. ,65. ,66. ,65.5 ,
        66.5, 67.5, 67. ,68. ,69. ,68.78125 ,69.78125 ,70.78125 ,70., 71. ,
        72. ,70.28125 ,71.28125 ,72.28125 ,65.875 ,66.875, 67.875 ,67.09375 ,68.09375 ,69.09375,
        68.875 ,69.875 ,70.875, 70.375 ,71.375 ,72.375 ,71.875 ,72.875 ,73.875 ,73.65625,
        74.65625 ,75.65625 ,74.875 ,75.875, 76.875 ,75.15625 ,76.15625,
        77.15625 ,73. ,74. ,75., 74.21875 ,75.21875 ,76.21875,
        76. ,77. ,78. ,77.5 ,78.5 ,79.5 ,79.,
        80. ,81. ,80.78125 ,81.78125, 82.78125 ,82. ,83.,
        84. ,82.28125 ,83.28125 ,84.28125, 79. ,80. ,81.,
        80.21875 ,81.21875 ,82.21875 ,82., 83. ,84. ,83.5,
        84.5 ,85.5 ,85. ,86., 87. ,86.78125 ,87.78125,
        88.78125 ,88. ,89. ,90., 88.28125 ,89.28125 ,90.28125,
        85. ,86. ,87. ,86.21875, 87.21875 ,88.21875 ,88.,
        89. ,90. ,89.5 ,90.5, 91.5 ,91. ,92.,
        93. ,92.78125 ,93.78125 ,94.78125, 94. ,95. ,96.,
        94.28125 ,95.28125 ,96.28125 ,91., 92. ,93. ,92.21875,
        93.21875 ,94.21875 ,94. ,95., 96. ,95.5 ,96.5,
        97.5 ,97. ,98. ,99., 98.78125 ,99.78125 ,100.78125,
        100. ,101. ,102. ,100.28125, 101.28125 ,102.28125, 97.,
        98. ,99. ,98.21875 ,99.21875, 100.21875 ,100., 101.,
        102. ,101.5 ,102.5 ,103.5, 103. ,104., 105.,
        104.78125 ,105.78125 ,106.78125 ,106., 107. ,108., 106.28125,
        107.28125 ,108.28125 ,104.125 ,105.125, 106.125 ,105.34375, 106.34375,
        107.34375 ,107.125 ,108.125 ,109.125, 108.625 ,109.625, 110.625,
        110.125 ,111.125 ,112.125 ,111.90625, 112.90625 ,113.90625, 113.125,
        114.125 ,115.125 ,113.40625 ,114.40625, 115.40625 ,109., 110.,
        111. ,110.21875 ,111.21875 ,112.21875, 112., 113., 114.,
        113.5 ,114.5 ,115.5 ,115., 116., 117., 116.78125,
        117.78125 ,118.78125 ,118. ,119., 120., 118.28125, 119.28125,
        120.28125 ,110.125 ,111.125 ,112.125, 111.34375, 112.34375, 113.34375,
        113.125 ,114.125 ,115.125 ,114.625, 115.625, 116.625, 116.125,
        117.125 ,118.125 ,117.90625, 118.90625, 119.90625, 119.125, 120.125,
        121.125 ,119.40625 ,120.40625, 121.40625});

    // Requested output size, passed to the op as a second (int) input array.
    auto targetSize = NDArrayFactory::create<int>({10, 8});

    nd4j::ops::resize_bicubic op;
    auto outputs = op.execute({&image, &targetSize}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* resized = outputs->at(0);
    // Debug aids, disabled by default:
    // resized->printBuffer("Resized to 10x8");
    // reference.printBuffer("Expect for 10x8");
    ASSERT_TRUE(reference.isSameShape(resized));
    ASSERT_TRUE(reference.equalsTo(resized));

    delete outputs;
}
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test3) {
    // resize_bicubic on a 1x3x3x4 ramp (1, 2, 3, ...): dimensions 1 and 2 are
    // resized from 3x3 to 6x6 and the output is compared to hard-coded values.
    auto input = NDArrayFactory::create<double>('c', {1, 3, 3, 4});
    input.linspace(1);

    // Requested size, handed to the op as its second (integer) input.
    auto newSize = NDArrayFactory::create<int>({6, 6});

    // Each line below holds the 4 values of the last dimension; six lines per
    // index along dimension 1 ('c' order).
    auto expected = NDArrayFactory::create<double>('c', {1, 6, 6, 4}, {
        // dimension-1 index 0
        1., 2., 3., 4.,
        2.625, 3.625, 4.625, 5.625,
        5., 6., 7., 8.,
        7.375, 8.375, 9.375, 10.375,
        9., 10., 11., 12.,
        9.375, 10.375, 11.375, 12.375,
        // dimension-1 index 1
        5.875, 6.875, 7.875, 8.875,
        7.5, 8.5, 9.5, 10.5,
        9.875, 10.875, 11.875, 12.875,
        12.25, 13.25, 14.25, 15.25,
        13.875, 14.875, 15.875, 16.875,
        14.25, 15.25, 16.25, 17.25,
        // dimension-1 index 2
        13., 14., 15., 16.,
        14.625, 15.625, 16.625, 17.625,
        17., 18., 19., 20.,
        19.375, 20.375, 21.375, 22.375,
        21., 22., 23., 24.,
        21.375, 22.375, 23.375, 24.375,
        // dimension-1 index 3
        20.125, 21.125, 22.125, 23.125,
        21.75, 22.75, 23.75, 24.75,
        24.125, 25.125, 26.125, 27.125,
        26.5, 27.5, 28.5, 29.5,
        28.125, 29.125, 30.125, 31.125,
        28.5, 29.5, 30.5, 31.5,
        // dimension-1 index 4
        25., 26., 27., 28.,
        26.625, 27.625, 28.625, 29.625,
        29., 30., 31., 32.,
        31.375, 32.375, 33.375, 34.375,
        33., 34., 35., 36.,
        33.375, 34.375, 35.375, 36.375,
        // dimension-1 index 5
        26.125, 27.125, 28.125, 29.125,
        27.75, 28.75, 29.75, 30.75,
        30.125, 31.125, 32.125, 33.125,
        32.5, 33.5, 34.5, 35.5,
        34.125, 35.125, 36.125, 37.125,
        34.5, 35.5, 36.5, 37.5
    });

    nd4j::ops::resize_bicubic resizeOp;
    auto outputs = resizeOp.execute({&input, &newSize}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* resized = outputs->at(0);
    // Debug aid: resized->printBuffer("Resized to 6x6");
    // Debug aid: expected.printBuffer("Expect for 6x6");

    ASSERT_TRUE(expected.isSameShape(resized));
    ASSERT_TRUE(expected.equalsTo(resized));

    delete outputs;
}
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test4) {
    // resize_bicubic on a 1x3x4x3 ramp (1, 2, 3, ...): dimensions 1 and 2 are
    // resized from 3x4 to 6x8 and the output is compared to hard-coded values.
    auto input = NDArrayFactory::create<double>('c', {1, 3, 4, 3});
    input.linspace(1);

    // Requested size, handed to the op as its second (integer) input.
    auto newSize = NDArrayFactory::create<int>({6, 8});

    // Each line below holds the 3 values of the last dimension; eight lines
    // per index along dimension 1 ('c' order).
    auto expected = NDArrayFactory::create<double>('c', {1, 6, 8, 3}, {
        // dimension-1 index 0
        1., 2., 3.,
        2.21875, 3.21875, 4.21875,
        4., 5., 6.,
        5.5, 6.5, 7.5,
        7., 8., 9.,
        8.78125, 9.78125, 10.78125,
        10., 11., 12.,
        10.28125, 11.28125, 12.28125,
        // dimension-1 index 1
        5.875, 6.875, 7.875,
        7.09375, 8.09375, 9.09375,
        8.875, 9.875, 10.875,
        10.375, 11.375, 12.375,
        11.875, 12.875, 13.875,
        13.65625, 14.65625, 15.65625,
        14.875, 15.875, 16.875,
        15.15625, 16.15625, 17.15625,
        // dimension-1 index 2
        13., 14., 15.,
        14.21875, 15.21875, 16.21875,
        16., 17., 18.,
        17.5, 18.5, 19.5,
        19., 20., 21.,
        20.78125, 21.78125, 22.78125,
        22., 23., 24.,
        22.28125, 23.28125, 24.28125,
        // dimension-1 index 3
        20.125, 21.125, 22.125,
        21.34375, 22.34375, 23.34375,
        23.125, 24.125, 25.125,
        24.625, 25.625, 26.625,
        26.125, 27.125, 28.125,
        27.90625, 28.90625, 29.90625,
        29.125, 30.125, 31.125,
        29.40625, 30.40625, 31.40625,
        // dimension-1 index 4
        25., 26., 27.,
        26.21875, 27.21875, 28.21875,
        28., 29., 30.,
        29.5, 30.5, 31.5,
        31., 32., 33.,
        32.78125, 33.78125, 34.78125,
        34., 35., 36.,
        34.28125, 35.28125, 36.28125,
        // dimension-1 index 5
        26.125, 27.125, 28.125,
        27.34375, 28.34375, 29.34375,
        29.125, 30.125, 31.125,
        30.625, 31.625, 32.625,
        32.125, 33.125, 34.125,
        33.90625, 34.90625, 35.90625,
        35.125, 36.125, 37.125,
        35.40625, 36.40625, 37.40625
    });

    nd4j::ops::resize_bicubic resizeOp;
    auto outputs = resizeOp.execute({&input, &newSize}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* resized = outputs->at(0);
    // Debug aid: resized->printBuffer("Resized to 6x8");
    // Debug aid: expected.printBuffer("Expect for 6x8");

    ASSERT_TRUE(expected.isSameShape(resized));
    ASSERT_TRUE(expected.equalsTo(resized));

    delete outputs;
}
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test5) {
    // resize_bicubic on a 1x4x4x3 ramp (1, 2, 3, ...): dimensions 1 and 2 are
    // resized from 4x4 to 8x8 and the output is compared to hard-coded values.
    auto input = NDArrayFactory::create<double>('c', {1, 4, 4, 3});
    input.linspace(1);

    // Requested size, handed to the op as its second (integer) input.
    auto newSize = NDArrayFactory::create<int>({8, 8});

    // 24 values (8 x 3) per index along dimension 1, i.e. three lines each.
    auto expected = NDArrayFactory::create<double>('c', {1, 8, 8, 3}, {
        // dimension-1 index 0
        1., 2., 3., 2.21875, 3.21875, 4.21875, 4., 5.,
        6., 5.5, 6.5, 7.5, 7., 8., 9., 8.78125,
        9.78125, 10.78125, 10., 11., 12., 10.28125, 11.28125, 12.28125,
        // dimension-1 index 1
        5.875, 6.875, 7.875, 7.09375, 8.09375, 9.09375, 8.875, 9.875,
        10.875, 10.375, 11.375, 12.375, 11.875, 12.875, 13.875, 13.65625,
        14.65625, 15.65625, 14.875, 15.875, 16.875, 15.15625, 16.15625, 17.15625,
        // dimension-1 index 2
        13., 14., 15., 14.21875, 15.21875, 16.21875, 16., 17.,
        18., 17.5, 18.5, 19.5, 19., 20., 21., 20.78125,
        21.78125, 22.78125, 22., 23., 24., 22.28125, 23.28125, 24.28125,
        // dimension-1 index 3
        19., 20., 21., 20.21875, 21.21875, 22.21875, 22., 23.,
        24., 23.5, 24.5, 25.5, 25., 26., 27., 26.78125,
        27.78125, 28.78125, 28., 29., 30., 28.28125, 29.28125, 30.28125,
        // dimension-1 index 4
        25., 26., 27., 26.21875, 27.21875, 28.21875, 28., 29.,
        30., 29.5, 30.5, 31.5, 31., 32., 33., 32.78125,
        33.78125, 34.78125, 34., 35., 36., 34.28125, 35.28125, 36.28125,
        // dimension-1 index 5
        32.125, 33.125, 34.125, 33.34375, 34.34375, 35.34375, 35.125, 36.125,
        37.125, 36.625, 37.625, 38.625, 38.125, 39.125, 40.125, 39.90625,
        40.90625, 41.90625, 41.125, 42.125, 43.125, 41.40625, 42.40625, 43.40625,
        // dimension-1 index 6
        37., 38., 39., 38.21875, 39.21875, 40.21875, 40., 41.,
        42., 41.5, 42.5, 43.5, 43., 44., 45., 44.78125,
        45.78125, 46.78125, 46., 47., 48., 46.28125, 47.28125, 48.28125,
        // dimension-1 index 7
        38.125, 39.125, 40.125, 39.34375, 40.34375, 41.34375, 41.125, 42.125,
        43.125, 42.625, 43.625, 44.625, 44.125, 45.125, 46.125, 45.90625,
        46.90625, 47.90625, 47.125, 48.125, 49.125, 47.40625, 48.40625, 49.40625
    });

    nd4j::ops::resize_bicubic resizeOp;
    auto outputs = resizeOp.execute({&input, &newSize}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* resized = outputs->at(0);
    // Debug aid: resized->printBuffer("Resized to 8x8");
    // Debug aid: expected.printBuffer("Expect for 8x8");

    ASSERT_TRUE(expected.isSameShape(resized));
    ASSERT_TRUE(expected.equalsTo(resized));

    delete outputs;
}
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) {
// Runs resize_bicubic on a hand-written 7x7x1 input with a requested size of
// 30x30 and compares the output against the hard-coded 30x30x1 table below.
// Unlike ImageResizeBicubic_Test3..Test5, the input here is rank 3 and is
// filled with explicit values instead of linspace().
// NOTE(review): the input is deliberately not a plain ramp (e.g. 2.1, 3.15,
// 12.9, and the step from 28 to 30 between rows) - presumably to exercise
// non-linear interpolation; confirm with the author of the reference data.
NDArray input = NDArrayFactory::create<double>('c', {7, 7, 1}, {
1, 2.1, 3.15, 4.2, 5.15, 6.1, 7,
8, 9.1, 10., 11, 12.9, 13.1, 14,
15, 16., 17., 18, 19, 20., 21,
22, 23., 24., 25, 26, 27, 28,
30, 31, 32, 33, 34., 35, 36,
37, 38, 39, 40, 41., 42, 43,
44, 45, 46, 47, 48., 49, 50
});
// Reference output: 900 values in 'c' order, i.e. 30 values per index along
// dimension 0 (five source lines of six values each).
NDArray expected = NDArrayFactory::create<double>('c', {30, 30, 1}, {
1. ,1.1976162 ,1.4174359 ,1.6775769 ,1.9961575 ,2.3283265 ,
2.550918 ,2.7360606 ,2.9655411 ,3.2929654 ,3.5441515 ,3.7380352 ,
3.948995 ,4.248106 ,4.5073795 ,4.6843743 ,4.8572845 ,5.104302 ,
5.3869915 ,5.581401 ,5.7539616 ,5.974285 ,6.272836 ,6.5204263 ,
6.718899 ,6.8871036 ,7.039068 ,7.099216 ,7.0784245 ,7.0281887 ,
2.247592 ,2.446947 ,2.6694887 ,2.9312382 ,3.248216 ,3.5745337 ,
3.78931 ,3.9656973 ,4.186417 ,4.5046535 ,4.740569 ,4.9217057 ,
5.133866 ,5.459533 ,5.7744613 ,6.0197873 ,6.254011 ,6.535633 ,
6.8097296 ,6.9607787 ,7.0749416 ,7.241601 ,7.5094895 ,7.7499495 ,
7.954571 ,8.131972 ,8.286526 ,8.346463 ,8.325745 ,8.275683 ,
3.6286845 ,3.830573 ,4.0569587 ,4.3211575 ,4.6364856 ,4.9556503 ,
5.160583 ,5.3258467 ,5.535462 ,5.84216 ,6.058749 ,6.223753 ,
6.437597 ,6.797369 ,7.1836042 ,7.5164022 ,7.8290343 ,8.154773 ,
8.417635 ,8.512958 ,8.5521 ,8.649708 ,8.87788 ,9.108794 ,
9.320926 ,9.509781 ,9.667375 ,9.72694 ,9.706349 ,9.656599 ,
5.276778 ,5.480438 ,5.709702 ,5.9754477 ,6.288551 ,6.6005697 ,
6.796207 ,6.9511423 ,7.1503997 ,7.4461427 ,7.644651 ,7.794562 ,
8.009684 ,8.400473 ,8.851847 ,9.26469 ,9.649218, 10.015648 ,
10.268647 ,10.313368 ,10.2843275 ,10.319379 ,10.512033 ,10.734956 ,
10.954604 ,11.154507 ,11.315369 ,11.374779 ,11.354242 ,11.304622 ,
7.325373 ,7.5284843 ,7.757575 ,8.022221 ,8.331997 ,8.638187 ,
8.827649 ,8.976217 ,9.168955 ,9.45726 ,9.6442375 ,9.784517 ,
9.999621, 10.407702 ,10.896234, 11.355122, 11.781423, 12.172186 ,
12.420712 ,12.4374485 ,12.370511 ,12.371386 ,12.545973 ,12.766424 ,
12.992249 ,13.20012 ,13.364252 ,13.424109 ,13.40342 ,13.353425 ,
9.493208 ,9.692467 ,9.9169445, 10.176801, 10.482199, 10.78547 ,
10.974367 ,11.123442 ,11.31637 ,11.603645 ,11.790616 ,11.930889 ,
12.144082 ,12.546447 ,13.024898 ,13.4723 ,13.889232 ,14.276275 ,
14.528972 ,14.555555 ,14.50145 ,14.515459 ,14.700572 ,14.927055 ,
15.156046 ,15.366046 ,15.532901 ,15.594008 ,15.5728855 ,15.521847 ,
10.970133 ,11.163599 ,11.380694 ,11.633735 ,11.935032 ,12.238887 ,
12.43254 ,12.588294 ,12.787534 ,13.079956 ,13.27752 ,13.426631 ,
13.636713 ,14.013844 ,14.441672 ,14.827978 ,15.191209 ,15.549808 ,
15.81343 ,15.881828 ,15.883522 ,15.950411 ,16.16933 ,16.40794 ,
16.636436 ,16.842583 ,17.010887 ,17.07363 ,17.05194 ,16.999537 ,
12.219155 ,12.406129 ,12.614796 ,12.860335 ,13.157928 ,13.464224 ,
13.665207 ,13.830567 ,14.039036 ,14.339629 ,14.552863 ,14.715049 ,
14.921564 ,15.264454 ,15.622843 ,15.924977 ,16.213829 ,16.532364 ,
16.8099 ,16.934835 ,17.012146 ,17.150164 ,17.413412 ,17.666712 ,
17.892765 ,18.09207 ,18.261044 ,18.325531 ,18.303238 ,18.249378 ,
13.7663965 ,13.947391 ,14.148263 ,14.386917 ,14.681246 ,14.990087 ,
15.198166 ,15.372728 ,15.590062 ,15.898583 ,16.126892 ,16.301655 ,
16.50487 ,16.815214 ,17.107498 ,17.329458 ,17.547403 ,17.827654 ,
18.118288 ,18.296928 ,18.4461 ,18.651634 ,18.956806 ,19.22382 ,
19.447308 ,19.639887 ,19.809319 ,19.875397 ,19.852556 ,19.797365 ,
15.9419365 ,16.118704 ,16.314133 ,16.547867 ,16.839561 ,17.14954 ,
17.361883 ,17.542162 ,17.764957 ,18.078188 ,18.315733 ,18.498205 ,
18.699116 ,18.988684 ,19.238989 ,19.410137 ,19.583265 ,19.839512 ,
20.13878 ,20.35177 ,20.546844 ,20.795671 ,21.128067 ,21.404358 ,
21.626736 ,21.8155 ,21.98561 ,22.052843 ,22.029604 ,21.973448 ,
17.53522 ,17.71077 ,17.904636 ,18.13695 ,18.42784 ,18.738056 ,
18.951529 ,19.133352 ,19.357613 ,19.672083 ,19.912102 ,20.096638 ,
20.296894 ,20.580765 ,20.819603 ,20.976887 ,21.137802 ,21.387535 ,
21.689209 ,21.911621 ,22.119276 ,22.37999 ,22.71991 ,22.998823 ,
23.22097 ,23.40876 ,23.57911 ,23.646685 ,23.623325 ,23.566887 ,
18.746353 ,18.922657 ,19.117487 ,19.350685 ,19.64207 ,19.952137 ,
20.164913 ,20.345781 ,20.569134 ,20.88284 ,21.12133 ,21.30459 ,
21.505253 ,21.792645 ,22.038572 ,22.204426 ,22.37289 ,22.626648 ,
22.926834 ,23.143423 ,23.343302 ,23.596668 ,23.931936 ,24.209232 ,
24.431519 ,24.619913 ,24.79011 ,24.857473 ,24.83419 ,24.777927 ,
20.16656 ,20.344206 ,20.540766 ,20.775532 ,21.067804 ,21.377607 ,
21.589132 ,21.768297 ,21.99003 ,22.302366 ,22.538124 ,22.719105 ,
22.920494 ,23.214176 ,23.472767 ,23.653934 ,23.83589 ,24.096842 ,
24.394371 ,24.600555 ,24.786541 ,25.026773 ,25.353731 ,25.62813 ,
25.850672 ,26.04014 ,26.210072 ,26.277063 ,26.253906 ,26.197956 ,
22.363024 ,22.54125 ,22.738552 ,22.973991 ,23.266647 ,23.57634 ,
23.787327 ,23.96576 ,24.186796 ,24.498543 ,24.733124 ,24.913122 ,
25.114826 ,25.411213 ,25.675262 ,25.863028 ,26.050789 ,26.314838 ,
26.611223 ,26.812925 ,26.992926 ,27.227505 ,27.550882 ,27.824034 ,
28.046684 ,28.236614 ,28.406433 ,28.473265 ,28.450163 ,28.394344 ,
24.429443 ,24.60767 ,24.80497 ,25.04041 ,25.333065 ,25.642756 ,
25.853743 ,26.032173 ,26.25321 ,26.564959 ,26.79954 ,26.97954 ,
27.181242 ,27.47763 ,27.74168 ,27.929441 ,28.117207 ,28.381254 ,
28.677637 ,28.879343 ,29.059345 ,29.293922 ,29.617298 ,29.890451 ,
30.113104 ,30.303034 ,30.472853 ,30.539684 ,30.516582 ,30.460762 ,
26. ,26.178228 ,26.375526 ,26.61097 ,26.903624 ,27.213314 ,
27.424305 ,27.602734 ,27.823772 ,28.135519 ,28.3701 ,28.550098 ,
28.7518 ,29.04819 ,29.312237 ,29.5 ,29.687763 ,29.951813 ,
30.2482 ,30.449903 ,30.629902 ,30.864483 ,31.187859 ,31.461012 ,
31.683659 ,31.873592 ,32.043407 ,32.11024 ,32.087135 ,32.03132 ,
27.570559 ,27.748787 ,27.946087 ,28.181528 ,28.474184 ,28.783876 ,
28.994865 ,29.173294 ,29.39433 ,29.70608 ,29.940659 ,30.120655 ,
30.32236 ,30.618746 ,30.882797 ,31.070557 ,31.25832 ,31.522371 ,
31.818754 ,32.02046 ,32.20046 ,32.43504 ,32.758415 ,33.031567 ,
33.25422 ,33.44415 ,33.613964 ,33.680794 ,33.657696 ,33.60188 ,
29.636976 ,29.815207 ,30.0125 ,30.247944 ,30.5406 ,30.85029 ,
31.061283 ,31.239712 ,31.46075 ,31.7725 ,32.00708 ,32.187077 ,
32.38878 ,32.685165 ,32.949215 ,33.13698 ,33.32474 ,33.58879 ,
33.885178 ,34.086884 ,34.26688 ,34.501457 ,34.824837 ,35.09799 ,
35.320637 ,35.510574 ,35.68039 ,35.747215 ,35.724117 ,35.6683 ,
31.83344 ,32.011665 ,32.20897 ,32.444412 ,32.73707 ,33.046757 ,
33.257744 ,33.436176 ,33.657207 ,33.96896 ,34.203537 ,34.383537 ,
34.58524 ,34.88163 ,35.145676 ,35.33344 ,35.521206 ,35.785255 ,
36.081642 ,36.28334 ,36.46334 ,36.69792 ,37.021297 ,37.294453 ,
37.517097 ,37.707027 ,37.876846 ,37.94368 ,37.920578 ,37.864758 ,
33.253647 ,33.431873 ,33.62917 ,33.864613 ,34.15727 ,34.466957 ,
34.677948 ,34.856377 ,35.077415 ,35.38916 ,35.623745 ,35.803745 ,
36.005447 ,36.301834 ,36.565884 ,36.753647 ,36.941406 ,37.205456 ,
37.50184 ,37.703545 ,37.883545 ,38.118122 ,38.4415 ,38.714653 ,
38.9373 ,39.127235 ,39.297054 ,39.363884 ,39.340782 ,39.28496 ,
34.464783 ,34.64301 ,34.840305 ,35.075752 ,35.368404 ,35.6781 ,
35.889088 ,36.067516 ,36.28855 ,36.6003 ,36.834885 ,37.014877 ,
37.216583 ,37.51297 ,37.77702 ,37.964783 ,38.152546 ,38.416595 ,
38.71298 ,38.914684 ,39.094685 ,39.32926 ,39.652645 ,39.925793 ,
40.14844 ,40.338375 ,40.508194 ,40.575024 ,40.55192 ,40.496105 ,
36.058067 ,36.23629 ,36.43359 ,36.669033 ,36.961685 ,37.271378 ,
37.48237 ,37.6608 ,37.881836 ,38.19359 ,38.42817 ,38.608162 ,
38.809868 ,39.10625 ,39.3703 ,39.558064 ,39.74583 ,40.00988 ,
40.306267 ,40.50797 ,40.68797 ,40.92255 ,41.245926 ,41.519077 ,
41.741722 ,41.931652 ,42.101475 ,42.168304 ,42.145203 ,42.089386 ,
38.315002 ,38.493233 ,38.690533 ,38.925976 ,39.218628 ,39.52832 ,
39.739307 ,39.917736 ,40.138775 ,40.45052 ,40.685104 ,40.865097 ,
41.066803 ,41.36319 ,41.627243 ,41.815002 ,42.002766 ,42.26682 ,
42.5632 ,42.764908 ,42.944904 ,43.179485 ,43.50286 ,43.776016 ,
43.998665 ,44.188595 ,44.358418 ,44.425247 ,44.402145 ,44.34633 ,
40.22708 ,40.40531 ,40.602608 ,40.83805 ,41.130707 ,41.440395 ,
41.651382 ,41.82982 ,42.050854 ,42.3626 ,42.597183 ,42.77718 ,
42.97888 ,43.27527 ,43.53932 ,43.72708 ,43.914845 ,44.178894 ,
44.47528 ,44.676983 ,44.856983 ,45.09156 ,45.41494 ,45.68809 ,
45.91074 ,46.100674 ,46.270493 ,46.337322 ,46.31422 ,46.2584 ,
41.785618 ,41.963844 ,42.161144 ,42.396584 ,42.68924 ,42.998936 ,
43.209923 ,43.388355 ,43.609394 ,43.921143 ,44.15572 ,44.335716 ,
44.53742 ,44.833805 ,45.09786 ,45.285614 ,45.473377 ,45.737427 ,
46.033817 ,46.235523 ,46.415524 ,46.650105 ,46.973476 ,47.24663 ,
47.469276 ,47.65921 ,47.82903 ,47.895855 ,47.872753 ,47.81694 ,
43.11514 ,43.293365 ,43.490665 ,43.726105 ,44.018764 ,44.328457 ,
44.539444 ,44.717873 ,44.93891 ,45.25066 ,45.48524 ,45.665237 ,
45.86694 ,46.163326 ,46.427376 ,46.615143 ,46.802902 ,47.066956 ,
47.363342 ,47.56505 ,47.74505 ,47.979626 ,48.302998 ,48.576153 ,
48.798798 ,48.98873 ,49.158546 ,49.225376 ,49.202282 ,49.146458 ,
44.303867 ,44.482094 ,44.679394 ,44.914833 ,45.207493 ,45.51718 ,
45.72817 ,45.9066 ,46.12764 ,46.439384 ,46.673965 ,46.853966 ,
47.055668 ,47.352055 ,47.6161 ,47.803867 ,47.99163 ,48.25568 ,
48.552063 ,48.75377 ,48.933773 ,49.16835 ,49.491726 ,49.764877 ,
49.987526 ,50.17746 ,50.347275 ,50.4141 ,50.391006 ,50.335186 ,
44.771675 ,44.949905 ,45.1472 ,45.382645 ,45.6753 ,45.98499 ,
46.195976 ,46.374413 ,46.595448 ,46.907196 ,47.141773 ,47.321774 ,
47.523476 ,47.819862 ,48.08391 ,48.27168 ,48.459446 ,48.72349 ,
49.019882 ,49.22158 ,49.401585 ,49.63616 ,49.959538 ,50.232693 ,
50.455338 ,50.64527 ,50.81509 ,50.88192 ,50.858818 ,50.803 ,
44.609966 ,44.788193 ,44.985493 ,45.220936 ,45.51359 ,45.82328 ,
46.03427 ,46.2127 ,46.433743 ,46.74549 ,46.98007 ,47.160065 ,
47.36177 ,47.658157 ,47.922207 ,48.10997 ,48.297733 ,48.561783 ,
48.858166 ,49.059875 ,49.239872 ,49.47445 ,49.79783 ,50.07098 ,
50.293625 ,50.48356 ,50.653378 ,50.720203 ,50.6971 ,50.64128 ,
44.219246 ,44.397472 ,44.594772 ,44.83021 ,45.122868 ,45.43256 ,
45.643543 ,45.82198 ,46.04302 ,46.354763 ,46.589344 ,46.76934 ,
46.971046 ,47.267433 ,47.531483 ,47.719242 ,47.907005 ,48.17105 ,
48.467438 ,48.66914 ,48.849144 ,49.08372 ,49.4071 ,49.680256 ,
49.902905 ,50.092834 ,50.262653 ,50.329483 ,50.30638 ,50.25057});
// Requested size, passed to the op as a second (integer) input array.
auto size = NDArrayFactory::create<int>({30, 30});
nd4j::ops::resize_bicubic op;
auto results = op.execute({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// Debug aids, left disabled:
// result->printBuffer("Resized to 30x30");
// expected.printBuffer("Expect for 30x30");
// Shape is checked first, then values via equalsTo().
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
// NOTE(review): a failing ASSERT_* above returns early and skips this delete.
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {
    // Checks copy-assignment and copy-construction of SummaryStatsData<double>:
    // assigning between array elements must copy the fields, and a copy of a
    // default-constructed object must still report n == 0.
    using StatsData = functions::summarystats::SummaryStatsData<double>;

    StatsData var1;   // left in its default-constructed state
    StatsData var2;
    var2.n = var2.mean = var2.M2 = var2.M3 = var2.M4 = var2.bias = 5;

    // Automatic storage instead of new[]/delete[]: a failing ASSERT_* returns
    // from the test body immediately, which would skip a trailing delete[] and
    // leak the array.
    StatsData arr[2];
    arr[0] = var1;
    arr[1] = var2;
    arr[0] = arr[1];   // element-to-element assignment overwrites the var1 copy

    StatsData var3(var1);

    ASSERT_TRUE(arr[0].n == arr[0].mean && arr[0].M2 == arr[0].M3 && arr[0].n == 5);
    ASSERT_TRUE(arr[1].n == arr[1].mean && arr[1].M2 == arr[1].M3 && arr[1].n == 5);
    ASSERT_TRUE(var3.n == var3.mean && var3.M2 == var3.M3 && var3.n == 0);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) {
    // mean_sqerr_loss_grad with integer argument 0 and weights of the same
    // shape as the predictions; all three gradient outputs are verified.
    NDArray predictions('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expectedPredGrad('c', {2, 3, 4}, {
        -0.96, -1.92, -2.88, -3.84, -4.8, -5.76, -6.72, -7.68, -8.64, -9.6, -10.56, -11.52,
        -12.48, -13.44, -14.4, -15.36, -16.32, -17.28, -18.24, -19.2, -20.16, -21.12, -22.08, -23.04});
    NDArray expectedWeightGrad('c', {2, 3, 4}, {
        0.9216, 3.6864, 8.2944, 14.7456, 23.04, 33.1776, 45.1584, 58.9824, 74.6496, 92.16, 111.51361, 132.7104,
        155.75038, 180.63359, 207.35999, 235.9296, 266.34238, 298.59842, 332.6976, 368.64001, 406.4256, 446.05444, 487.5264, 530.84161});

    nd4j::ops::mean_sqerr_loss_grad gradOp;
    auto outputs = gradOp.execute({&predictions, &weights, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* predGrad = outputs->at(0);
    NDArray* weightGrad = outputs->at(1);
    NDArray* labelGrad = outputs->at(2);

    ASSERT_TRUE(expectedPredGrad.isSameShape(predGrad));
    ASSERT_TRUE(expectedPredGrad.equalsTo(predGrad));

    ASSERT_TRUE(expectedWeightGrad.isSameShape(weightGrad));
    ASSERT_TRUE(expectedWeightGrad.equalsTo(weightGrad));

    // The label gradient has to match the negated prediction gradient.
    ASSERT_TRUE(expectedPredGrad.isSameShape(-*labelGrad));
    ASSERT_TRUE(expectedPredGrad.equalsTo(-*labelGrad));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test2) {
    // mean_sqerr_loss_grad with integer argument 0 and {2,1,4} weights against
    // {2,3,4} predictions; only the weights gradient is verified.
    NDArray predictions('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2, 1, 4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expectedWeightGrad('c', {2, 1, 4}, {
        98.61121, 129.024, 164.9664, 206.4384,
        828.51837, 925.28644, 1027.58398, 1135.41113});

    nd4j::ops::mean_sqerr_loss_grad gradOp;
    auto outputs = gradOp.execute({&predictions, &weights, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* weightGrad = outputs->at(1);
    ASSERT_TRUE(expectedWeightGrad.isSameShape(weightGrad));
    ASSERT_TRUE(expectedWeightGrad.equalsTo(weightGrad));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) {
    // mean_sqerr_loss_grad with integer argument 1 and a weights array built
    // from the data type only (expected weights gradient has shape {}).
    NDArray predictions('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expectedPredGrad('c', {2, 3, 4}, {
        -0.96, -1.92, -2.88, -3.84, -4.8, -5.76, -6.72, -7.68, -8.64, -9.6, -10.56, -11.52,
        -12.48, -13.44, -14.4, -15.36, -16.32, -17.28, -18.24, -19.2, -20.16, -21.12, -22.08, -23.04});
    NDArray expectedWeightGrad('c', {}, {4515.84});

    nd4j::ops::mean_sqerr_loss_grad gradOp;
    auto outputs = gradOp.execute({&predictions, &weights, &labels}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* predGrad = outputs->at(0);
    NDArray* weightGrad = outputs->at(1);
    NDArray* labelGrad = outputs->at(2);

    ASSERT_TRUE(expectedPredGrad.isSameShape(predGrad));
    ASSERT_TRUE(expectedPredGrad.equalsTo(predGrad));

    ASSERT_TRUE(expectedWeightGrad.isSameShape(weightGrad));
    ASSERT_TRUE(expectedWeightGrad.equalsTo(weightGrad));

    // The label gradient has to match the negated prediction gradient.
    ASSERT_TRUE(expectedPredGrad.isSameShape(-*labelGrad));
    ASSERT_TRUE(expectedPredGrad.equalsTo(-*labelGrad));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test4) {
    // mean_sqerr_loss_grad with integer argument 1 and {1,3,1} weights against
    // {2,3,4} predictions; only the weights gradient is verified.
    NDArray predictions('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1, 3, 1}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expectedWeightGrad('c', {1, 3, 1}, {807.32153, 1426.63684, 2281.88159});

    nd4j::ops::mean_sqerr_loss_grad gradOp;
    auto outputs = gradOp.execute({&predictions, &weights, &labels}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* weightGrad = outputs->at(1);
    ASSERT_TRUE(expectedWeightGrad.isSameShape(weightGrad));
    ASSERT_TRUE(expectedWeightGrad.equalsTo(weightGrad));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test5) {
    // mean_sqerr_loss_grad with integer argument 2 and weights of the same
    // shape as the predictions; all three gradient outputs are verified.
    NDArray predictions('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expectedPredGrad('c', {2, 3, 4}, {
        -0.08, -0.16, -0.24, -0.32, -0.4, -0.48, -0.56, -0.64, -0.72, -0.8, -0.88, -0.96,
        -1.04, -1.12, -1.2, -1.28, -1.36, -1.44, -1.52, -1.6, -1.68, -1.76, -1.84, -1.92});
    NDArray expectedWeightGrad('c', {2, 3, 4}, {
        -15.6032, -15.3728, -14.9888, -14.4512, -13.76, -12.9152, -11.9168, -10.7648, -9.4592, -8., -6.3872, -4.6208,
        -2.7008, -0.6272, 1.6, 3.9808, 6.5152, 9.2032, 12.0448, 15.04, 18.1888, 21.4912, 24.9472, 28.5568});

    nd4j::ops::mean_sqerr_loss_grad gradOp;
    auto outputs = gradOp.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* predGrad = outputs->at(0);
    NDArray* weightGrad = outputs->at(1);
    NDArray* labelGrad = outputs->at(2);

    ASSERT_TRUE(expectedPredGrad.isSameShape(predGrad));
    ASSERT_TRUE(expectedPredGrad.equalsTo(predGrad));

    ASSERT_TRUE(expectedWeightGrad.isSameShape(weightGrad));
    ASSERT_TRUE(expectedWeightGrad.equalsTo(weightGrad));

    // The label gradient has to match the negated prediction gradient.
    ASSERT_TRUE(expectedPredGrad.isSameShape(-*labelGrad));
    ASSERT_TRUE(expectedPredGrad.equalsTo(-*labelGrad));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test6) {
    // mean_sqerr_loss_grad with integer argument 2 and {1,3,1} weights against
    // {2,3,4} predictions; only the weights gradient is verified.
    NDArray predictions('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1, 3, 1}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expectedWeightGrad('c', {1, 3, 1}, {-58.16319, -6.5536, 64.71682});

    nd4j::ops::mean_sqerr_loss_grad gradOp;
    auto outputs = gradOp.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* weightGrad = outputs->at(1);
    ASSERT_TRUE(expectedWeightGrad.isSameShape(weightGrad));
    ASSERT_TRUE(expectedWeightGrad.equalsTo(weightGrad));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) {
    // mean_sqerr_loss_grad with integer argument 2 and a weights array built
    // from the data type only; the expected weights gradient is a single 0.
    NDArray predictions('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expectedWeightGrad('c', {}, {0.});

    nd4j::ops::mean_sqerr_loss_grad gradOp;
    auto outputs = gradOp.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* weightGrad = outputs->at(1);
    ASSERT_TRUE(expectedWeightGrad.isSameShape(weightGrad));
    ASSERT_TRUE(expectedWeightGrad.equalsTo(weightGrad));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test8) {
    // mean_sqerr_loss_grad with integer argument 2 where the first four
    // weights are zero; all three gradient outputs are verified.
    NDArray predictions('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // Zero out the leading four weights (linear indices 0..3).
    for (int idx = 0; idx < 4; ++idx) {
        weights.p(idx, 0.);
    }

    NDArray expectedPredGrad('c', {2, 3, 4}, {
        0., 0., 0., 0., -0.48, -0.576, -0.672, -0.768, -0.864, -0.96, -1.056, -1.152,
        -1.248, -1.344, -1.44, -1.536, -1.632, -1.728, -1.824, -1.92, -2.016, -2.112, -2.208, -2.304});
    NDArray expectedWeightGrad('c', {2, 3, 4}, {
        -22.3488, -22.07232, -21.61152, -20.9664, -20.13696, -19.1232, -17.92512, -16.54272, -14.976, -13.22496, -11.2896, -9.16992,
        -6.86592, -4.3776, -1.70496, 1.152, 4.19328, 7.41888, 10.8288, 14.42304, 18.2016, 22.16449, 26.31168, 30.6432});

    nd4j::ops::mean_sqerr_loss_grad gradOp;
    auto outputs = gradOp.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* predGrad = outputs->at(0);
    NDArray* weightGrad = outputs->at(1);
    NDArray* labelGrad = outputs->at(2);

    ASSERT_TRUE(expectedPredGrad.isSameShape(predGrad));
    ASSERT_TRUE(expectedPredGrad.equalsTo(predGrad));

    ASSERT_TRUE(expectedWeightGrad.isSameShape(weightGrad));
    ASSERT_TRUE(expectedWeightGrad.equalsTo(weightGrad));

    // The label gradient has to match the negated prediction gradient.
    ASSERT_TRUE(expectedPredGrad.isSameShape(-*labelGrad));
    ASSERT_TRUE(expectedPredGrad.equalsTo(-*labelGrad));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test9) {
    // mean_sqerr_loss_grad with integer argument 3 and weights of the same
    // shape as the predictions; all three gradient outputs are verified.
    NDArray predictions('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expectedPredGrad('c', {2, 3, 4}, {
        -0.04, -0.08, -0.12, -0.16, -0.2, -0.24, -0.28, -0.32, -0.36, -0.4, -0.44, -0.48,
        -0.52, -0.56, -0.6, -0.64, -0.68, -0.72, -0.76, -0.8, -0.84, -0.88, -0.92, -0.96});
    NDArray expectedWeightGrad('c', {2, 3, 4}, {
        0.0384, 0.1536, 0.3456, 0.6144, 0.96, 1.3824, 1.8816, 2.4576, 3.1104, 3.84, 4.6464, 5.5296,
        6.4896, 7.5264, 8.64, 9.8304, 11.0976, 12.4416, 13.8624, 15.36, 16.9344, 18.5856, 20.3136, 22.1184});

    nd4j::ops::mean_sqerr_loss_grad gradOp;
    auto outputs = gradOp.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* predGrad = outputs->at(0);
    NDArray* weightGrad = outputs->at(1);
    NDArray* labelGrad = outputs->at(2);

    ASSERT_TRUE(expectedPredGrad.isSameShape(predGrad));
    ASSERT_TRUE(expectedPredGrad.equalsTo(predGrad));

    ASSERT_TRUE(expectedWeightGrad.isSameShape(weightGrad));
    ASSERT_TRUE(expectedWeightGrad.equalsTo(weightGrad));

    // The label gradient has to match the negated prediction gradient.
    ASSERT_TRUE(expectedPredGrad.isSameShape(-*labelGrad));
    ASSERT_TRUE(expectedPredGrad.equalsTo(-*labelGrad));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test10) {
    // mean_sqerr_loss_grad with integer argument 3 and {1,1} weights against
    // {2,3,4} predictions; only the weights gradient is verified.
    NDArray predictions('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1, 1}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expectedWeightGrad('c', {1, 1}, {188.16});

    nd4j::ops::mean_sqerr_loss_grad gradOp;
    auto outputs = gradOp.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* weightGrad = outputs->at(1);
    ASSERT_TRUE(expectedWeightGrad.isSameShape(weightGrad));
    ASSERT_TRUE(expectedWeightGrad.equalsTo(weightGrad));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test11) {
    // mean_sqerr_loss_grad with integer argument 3 and {1,3,1} weights against
    // {2,3,4} predictions; only the weights gradient is verified.
    NDArray predictions('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1, 3, 1}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expectedWeightGrad('c', {1, 3, 1}, {33.6384, 59.4432, 95.07841});

    nd4j::ops::mean_sqerr_loss_grad gradOp;
    auto outputs = gradOp.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* weightGrad = outputs->at(1);
    ASSERT_TRUE(expectedWeightGrad.isSameShape(weightGrad));
    ASSERT_TRUE(expectedWeightGrad.equalsTo(weightGrad));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test12) {
    // mean_sqerr_loss_grad with integer argument 3 where the first four
    // weights are zero; all three gradient outputs are verified.
    NDArray predictions('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // Zero out the leading four weights (linear indices 0..3).
    for (int idx = 0; idx < 4; ++idx) {
        weights.t<double>(idx) = 0.;
    }

    NDArray expectedPredGrad('c', {2, 3, 4}, {
        0., 0., 0., 0., -0.24, -0.288, -0.336, -0.384, -0.432, -0.48, -0.528, -0.576,
        -0.624, -0.672, -0.72, -0.768, -0.816, -0.864, -0.912, -0.96, -1.008, -1.056, -1.104, -1.152});
    NDArray expectedWeightGrad('c', {2, 3, 4}, {
        0.04608, 0.18432, 0.41472, 0.73728, 1.152, 1.65888, 2.25792, 2.94912, 3.73248, 4.608, 5.57568, 6.63552,
        7.78752, 9.03168, 10.368, 11.79648, 13.31712, 14.92992, 16.63488, 18.432, 20.32128, 22.30272, 24.37632, 26.54208});

    nd4j::ops::mean_sqerr_loss_grad gradOp;
    auto outputs = gradOp.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* predGrad = outputs->at(0);
    NDArray* weightGrad = outputs->at(1);
    NDArray* labelGrad = outputs->at(2);

    ASSERT_TRUE(expectedPredGrad.isSameShape(predGrad));
    ASSERT_TRUE(expectedPredGrad.equalsTo(predGrad));

    ASSERT_TRUE(expectedWeightGrad.isSameShape(weightGrad));
    ASSERT_TRUE(expectedWeightGrad.equalsTo(weightGrad));

    // The label gradient has to match the negated prediction gradient.
    ASSERT_TRUE(expectedPredGrad.isSameShape(-*labelGrad));
    ASSERT_TRUE(expectedPredGrad.equalsTo(-*labelGrad));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test13) {
    // mean_sqerr_loss_grad with integer argument 3 and {2,3,1} weights whose
    // first three entries are zero; all three gradient outputs are verified.
    NDArray predictions('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2, 3, 1}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // Zero out the leading three weights (linear indices 0..2).
    for (int idx = 0; idx < 3; ++idx) {
        weights.t<double>(idx) = 0.;
    }

    NDArray expectedPredGrad('c', {2, 3, 4}, {
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        -1.04, -1.12, -1.2, -1.28, -1.36, -1.44, -1.52, -1.6, -1.68, -1.76, -1.84, -1.92});
    NDArray expectedWeightGrad('c', {2, 3, 1}, {
        2.304, 13.3632, 34.2528, 64.97279, 105.5232, 155.90401});

    nd4j::ops::mean_sqerr_loss_grad gradOp;
    auto outputs = gradOp.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* predGrad = outputs->at(0);
    NDArray* weightGrad = outputs->at(1);
    NDArray* labelGrad = outputs->at(2);

    ASSERT_TRUE(expectedPredGrad.isSameShape(predGrad));
    ASSERT_TRUE(expectedPredGrad.equalsTo(predGrad));

    ASSERT_TRUE(expectedWeightGrad.isSameShape(weightGrad));
    ASSERT_TRUE(expectedWeightGrad.equalsTo(weightGrad));

    // The label gradient has to match the negated prediction gradient.
    ASSERT_TRUE(expectedPredGrad.isSameShape(-*labelGrad));
    ASSERT_TRUE(expectedPredGrad.equalsTo(-*labelGrad));

    delete outputs;
}
TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test1) {
    // squaredsubtract on two length-4 vectors: (x - y)^2 element by element.
    auto lhs = NDArrayFactory::create<float>('c', {4}, {0, 1, 2, 3});
    auto rhs = NDArrayFactory::create<float>('c', {4}, {3, 2, 1, 0});
    auto expected = NDArrayFactory::create<float>('c', {4}, {9, 1, 1, 9});

    nd4j::ops::squaredsubtract sqSubOp;
    auto outputs = sqSubOp.execute({&lhs, &rhs}, {}, {});

    ASSERT_EQ(Status::OK(), outputs->status());
    ASSERT_TRUE(expected.equalsTo(outputs->at(0)));

    delete outputs;
}
TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test2) {
    // squaredsubtract with a {2,4} left operand and a {4} right operand; the
    // expected output repeats the same four values in both rows.
    auto lhs = NDArrayFactory::create<float>('c', {2, 4}, {0, 1, 2, 3, 0, 1, 2, 3});
    auto rhs = NDArrayFactory::create<float>('c', {4}, {3, 2, 1, 0});
    auto expected = NDArrayFactory::create<float>('c', {2, 4}, {9, 1, 1, 9, 9, 1, 1, 9});

    nd4j::ops::squaredsubtract sqSubOp;
    auto outputs = sqSubOp.execute({&lhs, &rhs}, {}, {});

    ASSERT_EQ(Status::OK(), outputs->status());
    ASSERT_TRUE(expected.equalsTo(outputs->at(0)));

    delete outputs;
}
TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test3) {
    // squaredsubtract_bp: only the first output is checked. The expected
    // values equal 2 * (x - y) * eps for the inputs below.
    auto lhs = NDArrayFactory::create<float>('c', {2, 4}, {0, 1, 2, 3, 0, 1, 2, 3});
    auto rhs = NDArrayFactory::create<float>('c', {4}, {3, 2, 1, 0});
    auto upstreamGrad = NDArrayFactory::create<float>('c', {2, 4}, {1, 2, 3, 4, 5, 6, 7, 8});
    auto expected = NDArrayFactory::create<float>('c', {2, 4}, {-6, -4, 6, 24, -30, -12, 14, 48});

    nd4j::ops::squaredsubtract_bp sqSubBpOp;
    auto outputs = sqSubBpOp.execute({&lhs, &rhs, &upstreamGrad}, {}, {});

    ASSERT_EQ(Status::OK(), outputs->status());
    ASSERT_TRUE(expected.equalsTo(outputs->at(0)));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test1) {
    // absolute_difference_loss_grad with integer argument 0 and weights of the
    // same shape as the predictions; all three gradient outputs are verified.
    NDArray predictions('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2, 3, 4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expectedPredGrad('c', {2, 3, 4}, {
        -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5,
        -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5});
    NDArray expectedWeightGrad('c', {2, 3, 4}, {
        0.96, 1.92, 2.88, 3.84, 4.8, 5.76, 6.72, 7.68, 8.64, 9.6, 10.56, 11.52,
        12.48, 13.44, 14.4, 15.36, 16.32, 17.28, 18.24, 19.2, 20.16, 21.12, 22.08, 23.04});

    nd4j::ops::absolute_difference_loss_grad gradOp;
    auto outputs = gradOp.execute({&predictions, &weights, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    NDArray* predGrad = outputs->at(0);
    NDArray* weightGrad = outputs->at(1);
    NDArray* labelGrad = outputs->at(2);

    ASSERT_TRUE(expectedPredGrad.isSameShape(predGrad));
    ASSERT_TRUE(expectedPredGrad.equalsTo(predGrad));

    ASSERT_TRUE(expectedWeightGrad.isSameShape(weightGrad));
    ASSERT_TRUE(expectedWeightGrad.equalsTo(weightGrad));

    // The label gradient has to match the negated prediction gradient.
    ASSERT_TRUE(expectedPredGrad.isSameShape(-*labelGrad));
    ASSERT_TRUE(expectedPredGrad.equalsTo(-*labelGrad));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// absolute_difference_loss_grad with weights of shape {2,1,4} against {2,3,4}
// predictions/labels and integer argument 0. Only the weights gradient
// (output 1) is checked.
// NOTE(review): the integer argument is presumably the reduction mode - confirm.
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test2) {
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,1,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradWeights('c', {2,1,4}, {14.4 , 17.28, 20.16, 23.04, 48.96, 51.84, 54.72, 57.6});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradWeights = outputs->at(1);
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// absolute_difference_loss_grad with weights constructed from a data type only
// (no shape) and integer argument 1. All three outputs are checked; the third
// is negated and compared against the expectation used for the first output.
// NOTE(review): the integer argument is presumably the reduction mode - confirm.
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) {
    // Inputs.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // Expected gradients.
    NDArray expGradPred('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,
                                       -0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5});
    NDArray expGradWeights('c', {}, {288.});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradPred = outputs->at(0);
    auto *gradWeights = outputs->at(1);
    auto *gradLabels = outputs->at(2);

    ASSERT_TRUE(expGradPred.isSameShape(gradPred));
    ASSERT_TRUE(expGradPred.equalsTo(gradPred));
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));

    // The label gradient is expected to be the negation of the prediction gradient.
    NDArray negGradLabels = -*gradLabels;
    ASSERT_TRUE(expGradPred.isSameShape(negGradLabels));
    ASSERT_TRUE(expGradPred.equalsTo(negGradLabels));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// absolute_difference_loss_grad with weights of shape {1,3,1} and integer
// argument 1. Only the weights gradient (output 1) is checked.
// NOTE(review): the integer argument is presumably the reduction mode - confirm.
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test4) {
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradWeights('c', {1,3,1}, {65.28, 96., 126.72001});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradWeights = outputs->at(1);
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// absolute_difference_loss_grad with element-wise weights of shape {2,3,4} and
// integer argument 2. All three outputs are checked; the third is negated and
// compared against the expectation used for the first output.
// NOTE(review): the integer argument is presumably the reduction mode - confirm.
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test5) {
    // Inputs.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // Expected gradients.
    NDArray expGradPred('c', {2,3,4}, {-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,
                                       -0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167});
    NDArray expGradWeights('c', {2,3,4}, {-0.92,-0.84,-0.76,-0.68,-0.6 ,-0.52,-0.44,-0.36,-0.28,-0.2 ,-0.12,-0.04,
                                          0.04, 0.12, 0.2 , 0.28, 0.36, 0.44, 0.52, 0.6 , 0.68, 0.76, 0.84, 0.92});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradPred = outputs->at(0);
    auto *gradWeights = outputs->at(1);
    auto *gradLabels = outputs->at(2);

    ASSERT_TRUE(expGradPred.isSameShape(gradPred));
    ASSERT_TRUE(expGradPred.equalsTo(gradPred));
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));

    // The label gradient is expected to be the negation of the prediction gradient.
    NDArray negGradLabels = -*gradLabels;
    ASSERT_TRUE(expGradPred.isSameShape(negGradLabels));
    ASSERT_TRUE(expGradPred.equalsTo(negGradLabels));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// absolute_difference_loss_grad with weights of shape {1,3,1} and integer
// argument 2. Only the weights gradient (output 1) is checked.
// NOTE(review): the integer argument is presumably the reduction mode - confirm.
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test6) {
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradWeights('c', {1,3,1}, {-2.56, 0., 2.56});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradWeights = outputs->at(1);
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// absolute_difference_loss_grad with weights constructed from a data type only
// (no shape) and integer argument 2. Only the weights gradient (output 1) is
// checked, and it is expected to be 0.
// NOTE(review): the integer argument is presumably the reduction mode - confirm.
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) {
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradWeights('c', {}, {0.});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradWeights = outputs->at(1);
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// absolute_difference_loss_grad with element-wise weights of shape {2,3,4}
// whose first four entries are set to zero, and integer argument 2.
// All three outputs are checked; the third is negated and compared against the
// expectation used for the first output.
// NOTE(review): the integer argument is presumably the reduction mode - confirm.
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test8) {
    // Inputs.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // Zero out the weights at linear indices 0..3.
    for (int i = 0; i < 4; ++i) {
        weights.p(i, 0.);
    }

    // Expected gradients.
    NDArray expGradPred('c', {2,3,4}, {-0. ,-0. ,-0. ,-0. ,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,
                                       -0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05});
    NDArray expGradWeights('c', {2,3,4}, {-1.296,-1.2 ,-1.104,-1.008,-0.912,-0.816,-0.72 ,-0.624,-0.528,-0.432,-0.336,-0.24 ,
                                          -0.144,-0.048, 0.048, 0.144, 0.24 , 0.336, 0.432, 0.528, 0.624, 0.72 , 0.816, 0.912});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradPred = outputs->at(0);
    auto *gradWeights = outputs->at(1);
    auto *gradLabels = outputs->at(2);

    ASSERT_TRUE(expGradPred.isSameShape(gradPred));
    ASSERT_TRUE(expGradPred.equalsTo(gradPred));
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));

    // The label gradient is expected to be the negation of the prediction gradient.
    NDArray negGradLabels = -*gradLabels;
    ASSERT_TRUE(expGradPred.isSameShape(negGradLabels));
    ASSERT_TRUE(expGradPred.equalsTo(negGradLabels));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// absolute_difference_loss_grad with element-wise weights of shape {2,3,4} and
// integer argument 3. All three outputs are checked; the third is negated and
// compared against the expectation used for the first output.
// NOTE(review): the integer argument is presumably the reduction mode - confirm.
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test9) {
    // Inputs.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // Expected gradients.
    NDArray expGradPred('c', {2,3,4}, {-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,
                                       -0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083});
    NDArray expGradWeights('c', {2,3,4}, {0.04, 0.08, 0.12, 0.16, 0.2 , 0.24, 0.28, 0.32,0.36, 0.4 , 0.44, 0.48,
                                          0.52, 0.56, 0.6 , 0.64,0.68, 0.72, 0.76, 0.8 ,0.84, 0.88, 0.92, 0.96});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradPred = outputs->at(0);
    auto *gradWeights = outputs->at(1);
    auto *gradLabels = outputs->at(2);

    ASSERT_TRUE(expGradPred.isSameShape(gradPred));
    ASSERT_TRUE(expGradPred.equalsTo(gradPred));
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));

    // The label gradient is expected to be the negation of the prediction gradient.
    NDArray negGradLabels = -*gradLabels;
    ASSERT_TRUE(expGradPred.isSameShape(negGradLabels));
    ASSERT_TRUE(expGradPred.equalsTo(negGradLabels));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// absolute_difference_loss_grad with weights of shape {1,1} and integer
// argument 3. Only the weights gradient (output 1) is checked.
// NOTE(review): the integer argument is presumably the reduction mode - confirm.
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test10) {
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradWeights('c', {1,1}, {12.});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradWeights = outputs->at(1);
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// absolute_difference_loss_grad with weights of shape {1,3,1} and integer
// argument 3. Only the weights gradient (output 1) is checked.
// NOTE(review): the integer argument is presumably the reduction mode - confirm.
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test11) {
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradWeights('c', {1,3,1}, {2.72, 4., 5.28});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradWeights = outputs->at(1);
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// absolute_difference_loss_grad with element-wise weights of shape {2,3,4}
// whose first four entries are set to zero, and integer argument 3.
// All three outputs are checked; the third is negated and compared against the
// expectation used for the first output.
// NOTE(review): the integer argument is presumably the reduction mode - confirm.
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test12) {
    // Inputs.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // Zero out the weights at linear indices 0..3.
    for (int i = 0; i < 4; ++i) {
        weights.t<double>(i) = 0.;
    }

    // Expected gradients.
    NDArray expGradPred('c', {2,3,4}, {0., 0., 0., 0., -0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025,
                                       -0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025});
    NDArray expGradWeights('c', {2,3,4}, {0.048, 0.096, 0.144, 0.192,0.24 , 0.288, 0.336, 0.384,0.432, 0.48 , 0.528, 0.576,
                                          0.624, 0.672, 0.72 , 0.768,0.816, 0.864, 0.912, 0.96 ,1.008, 1.056, 1.104, 1.152});

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradPred = outputs->at(0);
    auto *gradWeights = outputs->at(1);
    auto *gradLabels = outputs->at(2);

    ASSERT_TRUE(expGradPred.isSameShape(gradPred));
    ASSERT_TRUE(expGradPred.equalsTo(gradPred));
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));

    // The label gradient is expected to be the negation of the prediction gradient.
    NDArray negGradLabels = -*gradLabels;
    ASSERT_TRUE(expGradPred.isSameShape(negGradLabels));
    ASSERT_TRUE(expGradPred.equalsTo(negGradLabels));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// absolute_difference_loss_grad with weights of shape {2,3,1} whose first three
// entries are set to zero, and integer argument 3. All three outputs are
// checked; the third is negated and compared against the expectation used for
// the first output.
// NOTE(review): the integer argument is presumably the reduction mode - confirm.
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test13) {
    // Inputs.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,1}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // Zero out the weights at linear indices 0..2.
    for (int i = 0; i < 3; ++i) {
        weights.t<double>(i) = 0.;
    }

    // Expected gradients.
    NDArray expGradPred('c', {2,3,4}, {0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0.,
                                       -0.04167, -0.04167, -0.04167, -0.04167,-0.04167, -0.04167, -0.04167, -0.04167,-0.04167, -0.04167, -0.04167, -0.04167});
    NDArray expGradWeights('c', {2,3,1}, {0.8 ,2.08,3.36,4.64,5.92,7.2 });

    nd4j::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradPred = outputs->at(0);
    auto *gradWeights = outputs->at(1);
    auto *gradLabels = outputs->at(2);

    ASSERT_TRUE(expGradPred.isSameShape(gradPred));
    ASSERT_TRUE(expGradPred.equalsTo(gradPred));
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));

    // The label gradient is expected to be the negation of the prediction gradient.
    NDArray negGradLabels = -*gradLabels;
    ASSERT_TRUE(expGradPred.isSameShape(negGradLabels));
    ASSERT_TRUE(expGradPred.equalsTo(negGradLabels));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// add on two bfloat16 arrays of shape {2,3,4}; the result is compared by value
// against a bfloat16 array filled with linspace(2, 2).
TEST_F(DeclarableOpsTests11, BFloat16_Test_1) {
    NDArray lhs = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    NDArray rhs = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    NDArray expected = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    lhs.linspace(1);
    rhs.linspace(1);
    expected.linspace(2,2);

    nd4j::ops::add addOp;
    auto outputs = addOp.execute({&lhs, &rhs}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto sum = outputs->at(0);
    ASSERT_TRUE(sum->equalsTo(expected));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// add with mixed operand types: a float16 array plus a bfloat16 array.
// The result is compared by value against a float16 array filled with
// linspace(2, 2).
TEST_F(DeclarableOpsTests11, BFloat16_Test_2) {
    NDArray lhs = NDArrayFactory::create<float16>('c', {2,3,4});
    NDArray rhs = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    NDArray expected = NDArrayFactory::create<float16>('c', {2,3,4});
    lhs.linspace(1);
    rhs.linspace(1);
    expected.linspace(2,2);

    nd4j::ops::add addOp;
    auto outputs = addOp.execute({&lhs, &rhs}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto sum = outputs->at(0);
    ASSERT_TRUE(sum->equalsTo(expected));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// add on two BFLOAT16 arrays built directly through the NDArray constructor
// (rather than NDArrayFactory); the result is compared by value against a
// BFLOAT16 array filled with linspace(2, 2).
TEST_F(DeclarableOpsTests11, BFloat16_Test_3) {
    NDArray lhs('c', {2,3,4}, nd4j::DataType::BFLOAT16);
    NDArray rhs('c', {2,3,4}, nd4j::DataType::BFLOAT16);
    NDArray expected('c', {2,3,4}, nd4j::DataType::BFLOAT16);
    lhs.linspace(1);
    rhs.linspace(1);
    expected.linspace(2,2);

    nd4j::ops::add addOp;
    auto outputs = addOp.execute({&lhs, &rhs}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto sum = outputs->at(0);
    ASSERT_TRUE(sum->equalsTo(expected));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// sigm_cross_entropy_loss_grad with element-wise weights of shape {2,3,4},
// floating-point argument 0. and integer argument 0. All three outputs are
// checked for shape and value.
// NOTE(review): the floating-point argument is presumably label smoothing and
// the integer argument the reduction mode - confirm against the op definition.
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test1) {
    // Inputs.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // Expected gradients.
    NDArray expGradLogits('c', {2,3,4}, {-0.25999, -0.755 , -1.25 , -1.745 , -2.24001, -2.73502, -3.23004, -3.72508, -4.22014, -4.71523, -5.21034, -5.70548,
                                         -6.20066, -6.69587, -7.19113, -7.68643, -8.18177, -8.67717, -9.17262, -9.66813,-10.1637 ,-10.65932,-11.15501,-11.65077});
    NDArray expGradWeights('c', {2,3,4}, {0.73395, 0.75335, 0.69315, 0.55335, 0.33395, 0.03495, -0.34366, -0.80186, -1.33967, -1.95708, -2.65411, -3.43074,
                                          -4.28698, -5.22285, -6.23833, -7.33343, -8.50815, -9.76251,-11.0965 ,-12.51013,-14.00341,-15.57633,-17.2289 ,-18.96113});
    NDArray expGradLabels('c', {2,3,4}, {0.04, 0.02,-0. ,-0.02,-0.04,-0.06,-0.08,-0.1 ,-0.12,-0.14,-0.16,-0.18,
                                         -0.2 ,-0.22,-0.24,-0.26,-0.28,-0.3 ,-0.32,-0.34,-0.36,-0.38,-0.4 ,-0.42});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    auto *gradWeights = outputs->at(1);
    auto *gradLabels = outputs->at(2);

    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));
    ASSERT_TRUE(expGradLabels.isSameShape(gradLabels));
    ASSERT_TRUE(expGradLabels.equalsTo(gradLabels));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// sigm_cross_entropy_loss_grad with weights of shape {2,1,4}, floating-point
// argument 0.3 and integer argument 0. All three outputs are checked for shape
// and value.
// NOTE(review): the floating-point argument is presumably label smoothing and
// the integer argument the reduction mode - confirm against the op definition.
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test2) {
    // Inputs.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,1,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // Expected gradients.
    NDArray expGradLogits('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048,
                                         -4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577});
    NDArray expGradWeights('c', {2,1,4}, {0.43622, -0.19079, -0.98462, -1.94525,-18.09855,-20.72768,-23.52373,-26.48669});
    NDArray expGradLabels('c', {2,3,4}, {0.028, 0.014, -0. , -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126,
                                         -0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    auto *gradWeights = outputs->at(1);
    auto *gradLabels = outputs->at(2);

    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));
    ASSERT_TRUE(expGradLabels.isSameShape(gradLabels));
    ASSERT_TRUE(expGradLabels.equalsTo(gradLabels));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// sigm_cross_entropy_loss_grad with weights constructed from a data type only
// (no shape), floating-point argument 0.3 and integer argument 1. All three
// outputs are checked for shape and value.
// NOTE(review): the floating-point argument is presumably label smoothing and
// the integer argument the reduction mode - confirm against the op definition.
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) {
    // Inputs.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // Expected gradients.
    NDArray expGradLogits('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048,
                                         -4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577});
    NDArray expGradWeights('c', {}, {-91.52109});
    NDArray expGradLabels('c', {2,3,4}, {0.028, 0.014, -0., -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126,
                                         -0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    auto *gradWeights = outputs->at(1);
    auto *gradLabels = outputs->at(2);

    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));
    ASSERT_TRUE(expGradLabels.isSameShape(gradLabels));
    ASSERT_TRUE(expGradLabels.equalsTo(gradLabels));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// sigm_cross_entropy_loss_grad with weights of shape {1,3,1}, floating-point
// argument 0.3 and integer argument 1. Only the weights gradient (output 1) is
// checked.
// NOTE(review): the floating-point argument is presumably label smoothing and
// the integer argument the reduction mode - confirm against the op definition.
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test4) {
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradWeights('c', {1,3,1}, {-12.54779,-28.13393,-50.83936});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradWeights = outputs->at(1);
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// sigm_cross_entropy_loss_grad with element-wise weights of shape {2,3,4},
// floating-point argument 0.3 and integer argument 2. All three outputs are
// checked for shape and value.
// NOTE(review): the floating-point argument is presumably label smoothing and
// the integer argument the reduction mode - confirm against the op definition.
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test5) {
    // Inputs.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // Expected gradients.
    NDArray expGradLogits('c', {2,3,4}, {-0.01542,-0.04417,-0.07292,-0.10167,-0.13042,-0.15917,-0.18792,-0.21667,-0.24543,-0.27419,-0.30294,-0.33171,
                                         -0.36047,-0.38924,-0.41801,-0.44679,-0.47556,-0.50435,-0.53314,-0.56193,-0.59072,-0.61953,-0.64833,-0.67715});
    NDArray expGradWeights('c', {2,3,4}, {0.37794, 0.37906, 0.37554, 0.36739, 0.35461, 0.33719, 0.31514, 0.28846, 0.25714, 0.22119, 0.18061, 0.13539,
                                          0.08553, 0.03104,-0.02808,-0.09184,-0.16023,-0.23326,-0.31093,-0.39323,-0.48017,-0.57175,-0.66796,-0.76881});
    NDArray expGradLabels('c', {2,3,4}, {0.00233, 0.00117,-0.,-0.00117,-0.00233,-0.0035 ,-0.00467,-0.00583,-0.007 ,-0.00817,-0.00933,-0.0105,
                                         -0.01167,-0.01283,-0.014 ,-0.01517,-0.01633,-0.0175 ,-0.01867,-0.01983,-0.021 ,-0.02217,-0.02333,-0.0245});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    auto *gradWeights = outputs->at(1);
    auto *gradLabels = outputs->at(2);

    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));
    ASSERT_TRUE(expGradLabels.isSameShape(gradLabels));
    ASSERT_TRUE(expGradLabels.equalsTo(gradLabels));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// sigm_cross_entropy_loss_grad with weights of shape {1,3,1}, floating-point
// argument 0.3 and integer argument 2. Only the weights gradient (output 1) is
// checked.
// NOTE(review): the floating-point argument is presumably label smoothing and
// the integer argument the reduction mode - confirm against the op definition.
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test6) {
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradWeights('c', {1,3,1}, {1.4966 , 0.19776,-1.69436});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradWeights = outputs->at(1);
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// sigm_cross_entropy_loss_grad with weights constructed from a data type only
// (no shape), floating-point argument 0.3 and integer argument 2. Only the
// weights gradient (output 1) is checked, and it is expected to be 0.
// NOTE(review): the floating-point argument is presumably label smoothing and
// the integer argument the reduction mode - confirm against the op definition.
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) {
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradWeights('c', {}, {0.});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradWeights = outputs->at(1);
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// sigm_cross_entropy_loss_grad with element-wise weights of shape {2,3,4}
// whose first four entries are set to zero, floating-point argument 0.3 and
// integer argument 2. All three outputs are checked for shape and value.
// NOTE(review): the floating-point argument is presumably label smoothing and
// the integer argument the reduction mode - confirm against the op definition.
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test8) {
    // Inputs.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // Zero out the weights at linear indices 0..3.
    for (int i = 0; i < 4; ++i) {
        weights.p(i, 0.);
    }

    // Expected gradients.
    NDArray expGradLogits('c', {2,3,4}, { 0. , 0. , 0. , 0. ,-0.1565 ,-0.191 ,-0.2255 ,-0.26001,-0.29451,-0.32902,-0.36353,-0.39805,
                                         -0.43257,-0.46709,-0.50161,-0.53614,-0.57068,-0.60522,-0.63976,-0.67431,-0.70887,-0.74343,-0.778 ,-0.81258});
    NDArray expGradWeights('c', {2,3,4}, {0.54353, 0.54487, 0.54065, 0.53087, 0.51553, 0.49463, 0.46817, 0.43615, 0.39857, 0.35543, 0.30672, 0.25246,
                                          0.19264, 0.12725, 0.0563 ,-0.02021,-0.10228,-0.18992,-0.28312,-0.38188,-0.48621,-0.5961 ,-0.71156,-0.83258});
    NDArray expGradLabels('c', {2,3,4}, {-0. ,-0. , 0. , 0. ,-0.0028,-0.0042,-0.0056,-0.007 ,-0.0084,-0.0098,-0.0112,-0.0126,
                                         -0.014 ,-0.0154,-0.0168,-0.0182,-0.0196,-0.021 ,-0.0224,-0.0238,-0.0252,-0.0266,-0.028 ,-0.0294});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    auto *gradWeights = outputs->at(1);
    auto *gradLabels = outputs->at(2);

    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));
    ASSERT_TRUE(expGradLabels.isSameShape(gradLabels));
    ASSERT_TRUE(expGradLabels.equalsTo(gradLabels));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// sigm_cross_entropy_loss_grad with element-wise weights of shape {2,3,4},
// floating-point argument 0.3 and integer argument 3. All three outputs are
// checked for shape and value.
// NOTE(review): the floating-point argument is presumably label smoothing and
// the integer argument the reduction mode - confirm against the op definition.
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test9) {
    // Inputs.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // Expected gradients.
    NDArray expGradLogits('c', {2,3,4}, {-0.00771, -0.02208, -0.03646, -0.05083,-0.06521, -0.07958, -0.09396, -0.10834,-0.12271, -0.13709, -0.15147, -0.16585,
                                         -0.18024, -0.19462, -0.20901, -0.22339,-0.23778, -0.25217, -0.26657, -0.28096,-0.29536, -0.30976, -0.32417, -0.33857});
    NDArray expGradWeights('c', {2,3,4}, {0.03008, 0.03064, 0.02888, 0.02481, 0.01841, 0.00971, -0.00132, -0.01466,-0.03032, -0.0483 , -0.06859, -0.0912 ,
                                          -0.11612, -0.14337, -0.17293, -0.20481,-0.23901, -0.27552, -0.31435, -0.35551,-0.39898, -0.44476, -0.49287, -0.5433 });
    NDArray expGradLabels('c', {2,3,4}, {0.00117, 0.00058, -0. , -0.00058,-0.00117, -0.00175, -0.00233, -0.00292,-0.0035 , -0.00408, -0.00467, -0.00525,
                                         -0.00583, -0.00642, -0.007 , -0.00758,-0.00817, -0.00875, -0.00933, -0.00992,-0.0105 , -0.01108, -0.01167, -0.01225});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    auto *gradWeights = outputs->at(1);
    auto *gradLabels = outputs->at(2);

    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));
    ASSERT_TRUE(expGradLabels.isSameShape(gradLabels));
    ASSERT_TRUE(expGradLabels.equalsTo(gradLabels));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// sigm_cross_entropy_loss_grad with weights of shape {1,1}, floating-point
// argument 0.3 and integer argument 3. Only the weights gradient (output 1) is
// checked.
// NOTE(review): the floating-point argument is presumably label smoothing and
// the integer argument the reduction mode - confirm against the op definition.
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test10) {
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradWeights('c', {1,1}, {-3.81338});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradWeights = outputs->at(1);
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// sigm_cross_entropy_loss_grad with weights of shape {1,3,1}, floating-point
// argument 0.3 and integer argument 3. Only the weights gradient (output 1) is
// checked.
// NOTE(review): the floating-point argument is presumably label smoothing and
// the integer argument the reduction mode - confirm against the op definition.
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test11) {
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradWeights('c', {1,3,1}, {-0.52282,-1.17225,-2.11831});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradWeights = outputs->at(1);
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// sigm_cross_entropy_loss_grad with element-wise weights of shape {2,3,4}
// whose first four entries are set to zero, floating-point argument 0.3 and
// integer argument 3. All three outputs are checked for shape and value.
// NOTE(review): the floating-point argument is presumably label smoothing and
// the integer argument the reduction mode - confirm against the op definition.
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test12) {
    // Inputs.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // Zero out the weights at linear indices 0..3.
    for (int i = 0; i < 4; ++i) {
        weights.t<double>(i) = 0.;
    }

    // Expected gradients.
    NDArray expGradLogits('c', {2,3,4}, {0. , 0. , 0. , 0. ,-0.07825, -0.0955 , -0.11275, -0.13 ,-0.14726, -0.16451, -0.18177, -0.19902,
                                         -0.21628, -0.23354, -0.25081, -0.26807,-0.28534, -0.30261, -0.31988, -0.33716,-0.35443, -0.37172, -0.389 , -0.40629});
    NDArray expGradWeights('c', {2,3,4}, {0.0361 , 0.03677, 0.03466, 0.02977, 0.0221 , 0.01165, -0.00158, -0.01759,-0.03638, -0.05795, -0.08231, -0.10944,
                                          -0.13935, -0.17204, -0.20752, -0.24577,-0.28681, -0.33063, -0.37723, -0.42661,-0.47877, -0.53372, -0.59144, -0.65196});
    NDArray expGradLabels('c', {2,3,4}, {-0. , -0. , 0. , 0. ,-0.0014, -0.0021, -0.0028, -0.0035,-0.0042, -0.0049, -0.0056, -0.0063,
                                         -0.007 , -0.0077, -0.0084, -0.0091,-0.0098, -0.0105, -0.0112, -0.0119,-0.0126, -0.0133, -0.014 , -0.0147});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    auto *gradWeights = outputs->at(1);
    auto *gradLabels = outputs->at(2);

    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));
    ASSERT_TRUE(expGradLabels.isSameShape(gradLabels));
    ASSERT_TRUE(expGradLabels.equalsTo(gradLabels));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// sigm_cross_entropy_loss_grad with weights of shape {2,3,1} whose first three
// entries are set to zero, floating-point argument 0.3 and integer argument 3.
// All three outputs are checked for shape and value.
// NOTE(review): the floating-point argument is presumably label smoothing and
// the integer argument the reduction mode - confirm against the op definition.
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test13) {
    // Inputs.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,1}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // Zero out the weights at linear indices 0..2.
    for (int i = 0; i < 3; ++i) {
        weights.t<double>(i) = 0.;
    }

    // Expected gradients.
    NDArray expGradLogits('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
                                         -0.36047, -0.38924, -0.41801, -0.44679,-0.47556, -0.50435, -0.53314, -0.56193,-0.59072, -0.61953, -0.64833, -0.67715});
    NDArray expGradWeights('c', {2,3,1}, {0.22882, 0.02428,-0.4768 ,-1.27447,-2.36878,-3.75981});
    NDArray expGradLabels('c', {2,3,4}, {-0. , -0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.,
                                         -0.01167, -0.01283, -0.014 , -0.01517,-0.01633, -0.0175 , -0.01867, -0.01983,-0.021 , -0.02217, -0.02333, -0.0245});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    auto *gradWeights = outputs->at(1);
    auto *gradLabels = outputs->at(2);

    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));
    ASSERT_TRUE(expGradLabels.isSameShape(gradLabels));
    ASSERT_TRUE(expGradLabels.equalsTo(gradLabels));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// add with mixed operand types: a float array plus a bfloat16 array.
// The result is compared by value against a float array filled with
// linspace(2, 2).
TEST_F(DeclarableOpsTests11, BFloat16_Test_4) {
    NDArray lhs = NDArrayFactory::create<float>('c', {2,3,4});
    NDArray rhs = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    NDArray expected = NDArrayFactory::create<float>('c', {2,3,4});
    lhs.linspace(1);
    rhs.linspace(1);
    expected.linspace(2,2);

    nd4j::ops::add addOp;
    auto outputs = addOp.execute({&lhs, &rhs}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto sum = outputs->at(0);
    ASSERT_TRUE(sum->equalsTo(expected));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// subtract with mixed operand types: a float array (linspace(2, 2)) minus a
// bfloat16 array (linspace(1)). The result is compared by value against a
// float array filled with linspace(1).
TEST_F(DeclarableOpsTests11, BFloat16_Test_5) {
    NDArray lhs = NDArrayFactory::create<float>('c', {2,3,4});
    NDArray rhs = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    NDArray expected = NDArrayFactory::create<float>('c', {2,3,4});
    lhs.linspace(2, 2);
    rhs.linspace(1);
    expected.linspace(1);

    nd4j::ops::subtract subtractOp;
    auto outputs = subtractOp.execute({&lhs, &rhs}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto difference = outputs->at(0);
    ASSERT_TRUE(difference->equalsTo(expected));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// subtract with mixed operand types: a bfloat16 array (linspace(2, 2)) minus a
// double array (linspace(1)). The result is compared by value against a
// bfloat16 array filled with linspace(1).
TEST_F(DeclarableOpsTests11, BFloat16_Test_6) {
    NDArray lhs = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    NDArray rhs = NDArrayFactory::create<double>('c', {2,3,4});
    NDArray expected = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    lhs.linspace(2, 2);
    rhs.linspace(1);
    expected.linspace(1);

    nd4j::ops::subtract subtractOp;
    auto outputs = subtractOp.execute({&lhs, &rhs}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto difference = outputs->at(0);
    ASSERT_TRUE(difference->equalsTo(expected));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// softmax_cross_entropy_loss_grad with {2,4} INT32 labels, {2,4} logits and a {2} weights vector;
// executed with argument lists {0.} and {0}.
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test1) {
    // Inputs
    NDArray labels('c', {2,4}, {0,0,1,0, 0,1,0,0}, nd4j::DataType::INT32);
    NDArray logits('c', {2,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    weights.assign(0.5);

    // Reference values for outputs 0 and 1
    NDArray dLdpExpected('c', {2,4}, {0.1176, 0.1224, -0.3726, 0.1326, 0.1176, -0.3776, 0.1274, 0.1326});
    NDArray dLdwExpected('c', {2}, {1.36729, 1.40729});

    nd4j::ops::softmax_cross_entropy_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &weights, &labels}, {0.}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    auto dLdwActual = outputs->at(1);
    auto dLdlActual = outputs->at(2);   // fetched, but this test does not compare it

    ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));
    ASSERT_TRUE(dLdwExpected.isSameShape(dLdwActual));
    ASSERT_TRUE(dLdwExpected.equalsTo(dLdwActual));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// softmax_cross_entropy_loss_grad with rank-1 {4} labels/logits and a {1} weights array;
// executed with argument lists {0.} and {1}.
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test2) {
    // Inputs: every logit is set to 2, the single weight to 0.5
    NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
    NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1}, nd4j::DataType::DOUBLE);
    logits = 2.;
    weights.assign(0.5);

    // Reference values for outputs 0 and 1
    NDArray dLdpExpected('c', {4}, {0.125, 0.125, -0.375, 0.125});
    NDArray dLdwExpected('c', {1}, {1.38629});

    nd4j::ops::softmax_cross_entropy_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &weights, &labels}, {0.}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    auto dLdwActual = outputs->at(1);
    auto dLdlActual = outputs->at(2);   // fetched, but this test does not compare it

    ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));
    ASSERT_TRUE(dLdwExpected.isSameShape(dLdwActual));
    ASSERT_TRUE(dLdwExpected.equalsTo(dLdwActual));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// Same setup as a {4}-element labels/logits case, but weights (and the expected dLdw)
// are constructed with an empty shape list {}; executed with argument lists {0.} and {1}.
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test3) {
    // Inputs: every logit is set to 2; weights is created holding 0 and then assigned 0.5
    NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
    NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE);
    logits = 2.;
    weights.assign(0.5);

    // Reference values for outputs 0 and 1
    NDArray dLdpExpected('c', {4}, {0.125, 0.125, -0.375, 0.125});
    NDArray dLdwExpected('c', {}, {1.38629});

    nd4j::ops::softmax_cross_entropy_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &weights, &labels}, {0.}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    auto dLdwActual = outputs->at(1);
    auto dLdlActual = outputs->at(2);   // fetched, but this test does not compare it

    ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));
    ASSERT_TRUE(dLdwExpected.isSameShape(dLdwActual));
    ASSERT_TRUE(dLdwExpected.equalsTo(dLdwActual));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// {4}-element labels/logits with weights built from an empty shape list {};
// executed with argument lists {0.} and {2}. The expected dLdw here is 0.
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test4) {
    // Inputs: logits filled by linspace(-0.08, 0.04), weights set to 0.5
    NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
    NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    weights = 0.5;

    // Reference values for outputs 0 and 1
    NDArray dLdpExpected('c', {4}, {0.23521, 0.2448 , -0.7452 , 0.26519});
    NDArray dLdwExpected('c', {}, {0.});

    nd4j::ops::softmax_cross_entropy_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &weights, &labels}, {0.}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    auto dLdwActual = outputs->at(1);
    auto dLdlActual = outputs->at(2);   // fetched, but this test does not compare it

    ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));
    ASSERT_TRUE(dLdwExpected.isSameShape(dLdwActual));
    ASSERT_TRUE(dLdwExpected.equalsTo(dLdwActual));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// {4}-element labels/logits with a {1} weights array;
// executed with argument lists {0.} and {3}.
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test5) {
    // Inputs: logits filled by linspace(-0.08, 0.04), weights set to 0.5
    NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
    NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    weights = 0.5;

    // Reference values for outputs 0 and 1
    NDArray dLdpExpected('c', {4}, {0.1176, 0.1224, -0.3726, 0.1326});
    NDArray dLdwExpected('c', {1}, {1.36729});

    nd4j::ops::softmax_cross_entropy_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &weights, &labels}, {0.}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    auto dLdwActual = outputs->at(1);
    auto dLdlActual = outputs->at(2);   // fetched, but this test does not compare it

    ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));
    ASSERT_TRUE(dLdwExpected.isSameShape(dLdwActual));
    ASSERT_TRUE(dLdwExpected.equalsTo(dLdwActual));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// {2,4} labels/logits with a {2} weights vector; the only case in this group that
// passes a non-zero floating-point argument ({0.3}), together with integer argument {2}.
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test6) {
    // Inputs: logits filled by linspace(-0.08, 0.04), every weight set to 0.5
    NDArray labels('c', {2,4}, {0,0,1,0, 0,1,0,0}, nd4j::DataType::INT32);
    NDArray logits('c', {2,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    weights.assign(0.5);

    // Reference values for outputs 0 and 1
    NDArray dLdpExpected('c', {2,4}, {0.0801, 0.0849, -0.2601, 0.0951, 0.0801, -0.2651, 0.0899, 0.0951});
    NDArray dLdwExpected('c', {2}, {-0.014000, 0.014000});

    nd4j::ops::softmax_cross_entropy_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &weights, &labels}, {0.3}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    auto dLdwActual = outputs->at(1);
    auto dLdlActual = outputs->at(2);   // fetched, but this test does not compare it

    ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));
    ASSERT_TRUE(dLdwExpected.isSameShape(dLdwActual));
    ASSERT_TRUE(dLdwExpected.equalsTo(dLdwActual));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// Rank-3 {2,3,4} labels/logits with explicit {1,3} weights {0.5, 0., 1.5};
// executed with argument lists {0.} and {3}.
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test7) {
    // Inputs: weights are given literally, logits are filled by linspace(-0.08, 0.04)
    NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1, 1,0,0,0, 0,1,0,0}, nd4j::DataType::INT32);
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3}, {0.5, 0., 1.5});
    logits.linspace(-0.08, 0.04);

    // Reference values for outputs 0 and 1
    NDArray dLdpExpected('c', {2,3,4}, {-0.0956 , 0.0306 , 0.03185, 0.03315, 0.,-0., 0., 0., 0.0882 , 0.0918 ,-0.27945, 0.09945,
                                         0.0294 , 0.0306 , 0.03185,-0.09185,-0., 0., 0., 0., 0.0882 ,-0.2832 , 0.09555, 0.09945});
    NDArray dLdwExpected('c', {1,3}, {0.69365, 0.71365, 0.69365});

    nd4j::ops::softmax_cross_entropy_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &weights, &labels}, {0.}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    auto dLdwActual = outputs->at(1);
    auto dLdlActual = outputs->at(2);   // fetched, but this test does not compare it

    ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));
    ASSERT_TRUE(dLdwExpected.isSameShape(dLdwActual));
    ASSERT_TRUE(dLdwExpected.equalsTo(dLdwActual));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// Rank-4 {2,3,4,5} labels/logits with {1,1,4} weights; executed with argument lists {0.} and {2}.
// Only dLdw is verified: the dLdp comparison below is commented out (the reason is not recorded
// in this test), so dLdpExpected is constructed but never checked.
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test8) {
    // Inputs: logits filled by linspace(-0.08, 0.04), every weight set to 0.5
    NDArray labels('c', {2,3,4,5}, {1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,
                                    0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,
                                    0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0}, nd4j::DataType::INT32);
    NDArray logits('c', {2,3,4,5}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,1,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);
    weights.assign(0.5);

    // Reference values for outputs 0 (currently unchecked) and 1
    NDArray dLdpExpected('c', {2,3,4,5}, {-0.03399, 0.00799, 0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335,
                                           0.00866, 0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399,
                                           0.00799, 0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866,
                                           0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799,
                                           0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901,
                                           0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799, 0.00832,
                                           0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901, 0.00768,
                                           0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799, 0.00832, 0.00866,
                                           0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901});
    NDArray dLdwExpected('c', {1,1,4}, {0.005, 0.00167, -0.00167, -0.005});

    nd4j::ops::softmax_cross_entropy_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &weights, &labels}, {0.}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    auto dLdwActual = outputs->at(1);
    auto dLdlActual = outputs->at(2);   // fetched, but this test does not compare it

    // dLdpActual->printIndexedBuffer();
    // ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    // ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));
    ASSERT_TRUE(dLdwExpected.isSameShape(dLdwActual));
    ASSERT_TRUE(dLdwExpected.equalsTo(dLdwActual));

    delete outputs;
}
////////////////////////////////////////////////////////////////////////////////
// Smoke test: runs pairwise::SafeDivide between a reduced array and an INT64 divisor array,
// writing the result back into the reduced array. There are no assertions here, so the test
// only fails if one of these calls throws or crashes.
TEST_F(DeclarableOpsTests11, SafeDivideMixed_Test1) {
    NDArray source('c', {2, 3}, {1.0, 2.0, 3.0, -1.0, 2.0, 1.0});

    // Sum over dimension 1; the trailing `true` is passed through as in the original call.
    auto reduced = source.reduceAlongDims(reduce::Sum, {1}, true);

    // Divisor reuses the reduced array's shape info but is typed INT64, then filled with 1.
    NDArray divisor(reduced.getShapeInfo(), nd4j::DataType::INT64, false);
    divisor.assign(1);

    // In-place: `reduced` is both the left operand and the target.
    reduced.applyPairwiseTransform(pairwise::SafeDivide, &divisor, &reduced, nullptr);
}
/////////////////////////////////////////////////////////////////
// softmax_cross_entropy_loss_with_logits_grad on {2,3,4} inputs with empty argument lists.
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test1) {
    // Inputs: labels given literally, logits filled by linspace(-0.08, 0.04)
    NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1, 1,0,0,0, 0,1,0,0});
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);

    // Reference values for output 0
    NDArray dLdpExpected('c', {2,3,4}, {-0.76479, 0.2448, 0.2548, 0.26519, 0.23521,-0.7552, 0.2548, 0.26519, 0.23521, 0.2448,-0.7452, 0.26519,
                                         0.23521, 0.2448, 0.2548,-0.73481,-0.76479, 0.2448, 0.2548, 0.26519, 0.23521,-0.7552, 0.2548, 0.26519});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &labels}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// softmax_cross_entropy_loss_with_logits_grad on {2,3,4} inputs with integer argument {1}.
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test2) {
    // Inputs: labels given literally, logits filled by linspace(-0.08, 0.04)
    NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,1, 0,0,1,0, 0,0,0,1, 1,0,1,0, 0,1,0,0});
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);

    // Reference values for output 0
    NDArray dLdpExpected('c', {2,3,4}, {-0.71836, 0.28164, 0.28164, 0.28164, 0.33051, -0.66949, 0.33051, -0.66949, 0.38785, 0.38785, -0.61215, 0.38785,
                                         0.28164, 0.28164, 0.28164, -0.71836,-0.66949, 0.33051, -0.66949, 0.33051, 0.38785, -0.61215, 0.38785, 0.38785});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &labels}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// softmax_cross_entropy_loss_with_logits_grad on {2,3} inputs with integer argument {0}.
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test3) {
    // Inputs: labels given literally, logits filled by linspace(-0.08, 0.04)
    NDArray labels('c', {2,3}, {1,0,0, 0,1,1});
    NDArray logits('c', {2,3}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);

    // Reference values for output 0
    NDArray dLdpExpected('c', {2,3}, {-0.52996, 0.47004, 0.47004, 0.52996, -0.47004, -0.47004});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// softmax_cross_entropy_loss_with_logits_grad on {2,1} inputs with integer argument {1};
// the expected output is all zeros.
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test4) {
    // Inputs are fully specified by literals
    NDArray labels('c', {2,1}, {1,1});
    NDArray logits('c', {2,1}, {-0.04, 0.04});

    // Reference values for output 0
    NDArray dLdpExpected('c', {2,1}, {0., 0.});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &labels}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// softmax_cross_entropy_loss_with_logits_grad on {2,1} inputs with integer argument {0}.
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test5) {
    // Inputs are fully specified by literals
    NDArray labels('c', {2,1}, {1,0});
    NDArray logits('c', {2,1}, {-0.04, 0.04});

    // Reference values for output 0
    NDArray dLdpExpected('c', {2,1}, {-0.51999, 0.51999});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// softmax_cross_entropy_loss_with_logits_grad on {1,2} inputs with integer argument {0};
// the expected output is all zeros.
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test6) {
    // Inputs are fully specified by literals
    NDArray labels('c', {1,2}, {1,1});
    NDArray logits('c', {1,2}, {-0.04, 0.04});

    // Reference values for output 0
    NDArray dLdpExpected('c', {1,2}, {0, 0});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// softmax_cross_entropy_loss_with_logits_grad on rank-1 {2} inputs with integer argument {0}.
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test7) {
    // Inputs are fully specified by literals
    NDArray labels('c', {2}, {0,1});
    NDArray logits('c', {2}, {-0.04, 0.04});

    // Reference values for output 0
    NDArray dLdpExpected('c', {2}, {0.48001, -0.48001});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// softmax_cross_entropy_loss_with_logits_grad on single-element {1} inputs with integer
// argument {0}; the expected output is zero.
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test8) {
    // Inputs are fully specified by literals
    NDArray labels('c', {1}, {1});
    NDArray logits('c', {1}, {0.04});

    // Reference value for output 0
    NDArray dLdpExpected('c', {1}, {0});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// multiply_bp with a {3,4,5} first operand and a {1,1,1} second operand, all values set to 1.
// Only output 0 is checked, and it is expected to be all ones with shape {3,4,5}.
TEST_F(DeclarableOpsTests11, Multiply_BP_Test1) {
    // Operands
    NDArray x('c', {3,4,5}, nd4j::DataType::DOUBLE);
    NDArray y('c', {1,1,1}, nd4j::DataType::DOUBLE);
    x.assign(1.0);//linspace(0.1, 0.1);
    y.assign(1.0);

    // Third op input (named dLdp in the original test), also all ones
    NDArray incomingGrad('c', {3,4,5}, nd4j::DataType::DOUBLE);
    incomingGrad.assign(1.0);

    // Reference values for output 0
    NDArray expected('c', {3,4,5}, nd4j::DataType::DOUBLE);
    expected.assign(1.0);

    nd4j::ops::multiply_bp multiplyBpOp;
    auto outputs = multiplyBpOp.execute({&x, &y, &incomingGrad}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto firstOutput = outputs->at(0);
    ASSERT_TRUE(expected.isSameShape(firstOutput));
    ASSERT_TRUE(expected.equalsTo(firstOutput));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// sparse_softmax_cross_entropy_loss_with_logits_grad with {2} INT64 labels and {2,3} logits.
// Note the input order for this op in these tests: labels first, then logits.
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test1) {
    // Inputs: labels given literally, logits filled by linspace(0.1, 0.1)
    NDArray labels('c', {2}, {2,1}, nd4j::DataType::INT64);
    NDArray logits('c', {2,3}, nd4j::DataType::DOUBLE);
    logits.linspace(0.1, 0.1);

    // Reference values for output 0
    NDArray dLdpExpected('c', {2,3}, {0.30061, 0.33222, -0.63283, 0.30061, -0.66778, 0.36717});

    nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad lossGradOp;
    auto outputs = lossGradOp.execute({&labels, &logits}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// sparse_softmax_cross_entropy_loss_with_logits_grad with {2} INT64 labels {0,1} and {2,3} logits.
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test2) {
    // Inputs: labels given literally, logits filled by linspace(-0.1, 0.1)
    NDArray labels('c', {2}, {0,1}, nd4j::DataType::INT64);
    NDArray logits('c', {2,3}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.1, 0.1);

    // Reference values for output 0
    NDArray dLdpExpected('c', {2,3}, {-0.69939, 0.33222, 0.36717, 0.30061, -0.66778, 0.36717});

    nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad lossGradOp;
    auto outputs = lossGradOp.execute({&labels, &logits}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// sparse_softmax_cross_entropy_loss_with_logits_grad where labels is built from an empty
// shape list {} holding the single value 1, against rank-1 {2} logits.
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) {
    // Inputs are fully specified by literals
    NDArray labels('c', {}, {1}, nd4j::DataType::INT64);
    NDArray logits('c', {2}, {-0.2, 0.3});

    // Reference values for output 0
    NDArray dLdpExpected('c', {2}, {0.37754, -0.37754});

    nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad lossGradOp;
    auto outputs = lossGradOp.execute({&labels, &logits}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// sparse_softmax_cross_entropy_loss_with_logits_grad with {2,3} INT64 labels and {2,3,4} logits.
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test4) {
    // Inputs: labels given literally, logits filled by linspace(-0.5, 0.1)
    NDArray labels('c', {2,3}, {0,1,1, 3,3,2}, nd4j::DataType::INT64);
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.5, 0.1);

    // Reference values for output 0
    NDArray dLdpExpected('c', {2,3,4}, {-0.78616, 0.23633, 0.26118, 0.28865, 0.21384, -0.76367, 0.26118, 0.28865, 0.21384, -0.76367, 0.26118, 0.28865,
                                         0.21384, 0.23633, 0.26118, -0.71135, 0.21384, 0.23633, 0.26118, -0.71135, 0.21384, 0.23633, -0.73882, 0.28865});

    nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad lossGradOp;
    auto outputs = lossGradOp.execute({&labels, &logits}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));

    delete outputs;
}
/////////////////////////////////////////////////////////////////
// sparse_softmax_cross_entropy_loss_with_logits_grad with {1,1} INT64 labels and {1,1,2} logits.
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test5) {
    // Inputs are fully specified by literals
    NDArray labels('c', {1,1}, {0}, nd4j::DataType::INT64);
    NDArray logits('c', {1,1,2}, {-0.3,0.2});

    // Reference values for output 0
    NDArray dLdpExpected('c', {1,1,2}, {-0.62246, 0.62246});

    nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad lossGradOp;
    auto outputs = lossGradOp.execute({&labels, &logits}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto dLdpActual = outputs->at(0);
    ASSERT_TRUE(dLdpExpected.isSameShape(dLdpActual));
    ASSERT_TRUE(dLdpExpected.equalsTo(dLdpActual));

    delete outputs;
}