/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// Created by raver on 8/4/2018.
//

#include "testlayers.h"
#include <ops/declarable/CustomOperations.h>
#include <NDArray.h>
#include <ops/ops.h>
#include <GradCheck.h>

#include <memory>

using namespace nd4j;
|
|
|
|
|
|
class DeclarableOpsTests11 : public testing::Test {
|
|
public:
|
|
|
|
DeclarableOpsTests11() {
|
|
printf("\n");
|
|
fflush(stdout);
|
|
}
|
|
};
|
|
|
|
|
|
// listdiff keeps the elements of x that do not occur in y and reports their
// positions in x. Here x = {0, 1, 2, 3} and y = {3, 1}, so both the values and
// the indices are expected to be {0, 2}.
TEST_F(DeclarableOpsTests11, test_listdiff_1) {
    auto x = NDArrayFactory::create<int>('c', {4}, {0, 1, 2, 3});
    auto y = NDArrayFactory::create<int>('c', {2}, {3, 1});

    nd4j::ops::listdiff op;
    // unique_ptr frees the ResultSet even when a fatal ASSERT_* below returns
    // from the test early (a raw `delete` at the end would be skipped).
    std::unique_ptr<ResultSet> result(op.execute({&x, &y}, {}, {}));
    ASSERT_EQ(Status::OK(), result->status());
    ASSERT_EQ(2, result->size());

    auto values  = result->at(0);
    auto indices = result->at(1);

    ASSERT_EQ(2, values->lengthOf());
    ASSERT_EQ(2, indices->lengthOf());

    // e<int>() casts, so the checks hold whatever integer dtype the op emits.
    ASSERT_EQ(0, values->e<int>(0));
    ASSERT_EQ(2, values->e<int>(1));
    ASSERT_EQ(0, indices->e<int>(0));
    ASSERT_EQ(2, indices->e<int>(1));
}

///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, log_loss_grad_test1) {
|
|
|
|
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
|
|
NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692,
|
|
-24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911});
|
|
NDArray dLdwExp('c', {2,3,4}, {3.21887, 4.96807, 6.10512, 6.80726, 7.15461, 7.19051, 6.93973, 6.41584, 5.62456, 4.56548, 3.2326 , 1.61444,
|
|
-0.30659, -2.55529, -5.16569, -8.18417,-11.67468,-15.72734,-20.47379,-26.11644,-32.9902 ,-41.71318,-53.64824,-73.05434});
|
|
NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002,
|
|
-0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903});
|
|
|
|
predictions.linspace(0.04, 0.04);
|
|
labels.linspace(1);
|
|
weights.assign(0.5);
|
|
|
|
nd4j::ops::log_loss_grad op;
|
|
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {0}, {});
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
auto *dLdp = results->at(0);
|
|
auto *dLdw = results->at(1);
|
|
auto *dLdl = results->at(2);
|
|
|
|
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
|
|
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
|
|
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
|
|
|
|
delete results;
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, log_loss_grad_test2) {
|
|
|
|
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray weights('c', {2,1,4}, nd4j::DataType::DOUBLE);
|
|
|
|
NDArray dLdwExp('c', {2,1,4}, {15.99805, 16.72406, 16.27746, 14.83754,-44.97147,-59.99582,-79.28771,-107.35497});
|
|
|
|
predictions.linspace(0.04, 0.04);
|
|
labels.linspace(1);
|
|
weights.assign(0.5);
|
|
|
|
nd4j::ops::log_loss_grad op;
|
|
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {0});
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
auto *dLdw = results->at(1);
|
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
|
|
delete results;
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, log_loss_grad_test3) {
|
|
|
|
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray weights(nd4j::DataType::DOUBLE);
|
|
|
|
NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692,
|
|
-24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911});
|
|
NDArray dLdwExp('c', {}, {-227.77286});
|
|
NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002,
|
|
-0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903});
|
|
|
|
predictions.linspace(0.04, 0.04);
|
|
labels.linspace(1);
|
|
weights.assign(0.5);
|
|
|
|
nd4j::ops::log_loss_grad op;
|
|
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {1});
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
auto *dLdp = results->at(0);
|
|
auto *dLdw = results->at(1);
|
|
auto *dLdl = results->at(2);
|
|
|
|
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
|
|
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
|
|
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
|
|
|
|
delete results;
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, log_loss_grad_test4) {
|
|
|
|
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
|
|
|
|
NDArray dLdwExp('c', {1,3,1}, {4.8876 , -46.29156, -186.36887});
|
|
|
|
predictions.linspace(0.04, 0.04);
|
|
labels.linspace(1);
|
|
weights.assign(0.5);
|
|
|
|
nd4j::ops::log_loss_grad op;
|
|
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {1});
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
auto *dLdw = results->at(1);
|
|
// dLdw->printIndexedBuffer();
|
|
// dLdw->printShapeInfo();
|
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
|
|
delete results;
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, log_loss_grad_test5) {
|
|
|
|
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
|
|
NDArray dLdpExp('c', {2,3,4}, {-1.04166,-1.08696, -1.13636, -1.19048,-1.25 ,-1.31579, -1.38889, -1.47059,-1.5625 ,-1.66667, -1.78571, -1.92308,
|
|
-2.08333,-2.27273, -2.5 , -2.77778,-3.125 ,-3.57143, -4.16667, -5. ,-6.25 ,-8.33333,-12.49999,-24.99993});
|
|
NDArray dLdwExp('c', {2,3,4}, {1.05912, 1.20488, 1.29964, 1.35815, 1.3871 , 1.39009, 1.36919, 1.32553, 1.25959, 1.17133, 1.06026, 0.92541,
|
|
0.76533, 0.57794, 0.3604 , 0.10886,-0.18201,-0.51973,-0.91527,-1.38549,-1.95831,-2.68522,-3.67981,-5.29698});
|
|
NDArray dLdlExp('c', {2,3,4}, {0.13242, 0.10176, 0.08302, 0.06909, 0.05776, 0.04803, 0.03935, 0.03141, 0.02397, 0.01689, 0.01005, 0.00334,
|
|
-0.00334,-0.01005,-0.01689,-0.02397,-0.03141,-0.03935,-0.04803,-0.05776,-0.06909,-0.08302,-0.10176,-0.13242});
|
|
|
|
predictions.linspace(0.04, 0.04);
|
|
labels.linspace(1);
|
|
weights.assign(0.5);
|
|
|
|
nd4j::ops::log_loss_grad op;
|
|
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {2});
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
auto *dLdp = results->at(0);
|
|
auto *dLdw = results->at(1);
|
|
auto *dLdl = results->at(2);
|
|
|
|
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
|
|
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
|
|
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
|
|
|
|
delete results;
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, log_loss_grad_test6) {
|
|
|
|
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
|
|
|
|
NDArray dLdwExp('c', {1,3,1}, {6.73432, 2.46939,-9.20372});
|
|
|
|
predictions.linspace(0.04, 0.04);
|
|
labels.linspace(1);
|
|
weights.assign(0.5);
|
|
|
|
nd4j::ops::log_loss_grad op;
|
|
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {2});
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
auto *dLdw = results->at(1);
|
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
|
|
delete results;
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, log_loss_grad_test7) {
|
|
|
|
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray weights(nd4j::DataType::DOUBLE);
|
|
|
|
NDArray dLdwExp('c', {}, {0.});
|
|
|
|
predictions.linspace(0.04, 0.04);
|
|
labels.linspace(1);
|
|
weights.assign(0.5);
|
|
|
|
nd4j::ops::log_loss_grad op;
|
|
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {2});
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
auto *dLdw = results->at(1);
|
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
|
|
delete results;
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, log_loss_grad_test8) {
|
|
|
|
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
|
|
NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. ,-1.5 ,-1.57895, -1.66667, -1.76471,-1.875 ,-2. , -2.14286, -2.30769,
|
|
-2.5 ,-2.72727, -3. , -3.33333,-3.75 ,-4.28571, -5. , -6. ,-7.49999,-9.99999,-14.99999,-29.99991});
|
|
NDArray dLdwExp('c', {2,3,4}, {1.56625, 1.74117, 1.85487, 1.92509, 1.95982, 1.96341, 1.93833, 1.88594, 1.80682, 1.70091, 1.56762, 1.4058 ,
|
|
1.2137 , 0.98883, 0.72779, 0.42594, 0.07689,-0.32837,-0.80302,-1.36728,-2.05466,-2.92696,-4.12046,-6.06107});
|
|
NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0.06931, 0.05763, 0.04722, 0.03769, 0.02877, 0.02027, 0.01206, 0.004,
|
|
-0.004 ,-0.01206,-0.02027,-0.02877,-0.03769,-0.04722,-0.05763,-0.06931,-0.08291,-0.09962,-0.12212,-0.1589});
|
|
|
|
predictions.linspace(0.04, 0.04);
|
|
labels.linspace(1);
|
|
weights.assign(0.5);
|
|
weights.p(0, 0.);
|
|
weights.p(1, 0.);
|
|
weights.p(2, 0.);
|
|
weights.p(3, 0.);
|
|
|
|
nd4j::ops::log_loss_grad op;
|
|
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {2});
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
auto *dLdp = results->at(0);
|
|
auto *dLdw = results->at(1);
|
|
auto *dLdl = results->at(2);
|
|
|
|
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
|
|
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
|
|
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
|
|
|
|
delete results;
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, log_loss_grad_test9) {
|
|
|
|
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
|
|
NDArray dLdpExp('c', {2,3,4}, {-0.52083,-0.54348,-0.56818, -0.59524,-0.625 ,-0.65789,-0.69444, -0.73529,-0.78125,-0.83333,-0.89286, -0.96154,
|
|
-1.04167,-1.13636,-1.25 , -1.38889,-1.5625 ,-1.78571,-2.08333, -2.5 ,-3.125 ,-4.16666,-6.24999,-12.49996});
|
|
NDArray dLdwExp('c', {2,3,4}, {0.13412, 0.207 , 0.25438, 0.28364, 0.29811, 0.2996 , 0.28916, 0.26733, 0.23436, 0.19023, 0.13469, 0.06727,
|
|
-0.01277,-0.10647,-0.21524,-0.34101,-0.48645,-0.65531,-0.85307,-1.08819,-1.37459,-1.73805,-2.23534,-3.04393});
|
|
NDArray dLdlExp('c', {2,3,4}, {0.06621, 0.05088, 0.04151, 0.03455, 0.02888, 0.02401, 0.01968, 0.0157 , 0.01199, 0.00845, 0.00502, 0.00167,
|
|
-0.00167,-0.00502,-0.00845,-0.01199,-0.0157 ,-0.01968,-0.02401,-0.02888,-0.03455,-0.04151,-0.05088,-0.06621});
|
|
|
|
predictions.linspace(0.04, 0.04);
|
|
labels.linspace(1);
|
|
weights.assign(0.5);
|
|
|
|
nd4j::ops::log_loss_grad op;
|
|
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
auto *dLdp = results->at(0);
|
|
auto *dLdw = results->at(1);
|
|
auto *dLdl = results->at(2);
|
|
|
|
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
|
|
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
|
|
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
|
|
|
|
delete results;
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, log_loss_grad_test10) {
|
|
|
|
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
|
|
|
|
NDArray dLdwExp('c', {1,1}, {-9.49054});
|
|
|
|
predictions.linspace(0.04, 0.04);
|
|
labels.linspace(1);
|
|
weights.assign(0.5);
|
|
|
|
nd4j::ops::log_loss_grad op;
|
|
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
auto *dLdw = results->at(1);
|
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
|
|
delete results;
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, log_loss_grad_test11) {
|
|
|
|
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
|
|
|
|
NDArray dLdwExp('c', {1,3,1}, {0.20365,-1.92882,-7.76537});
|
|
|
|
predictions.linspace(0.04, 0.04);
|
|
labels.linspace(1);
|
|
weights.assign(0.5);
|
|
|
|
nd4j::ops::log_loss_grad op;
|
|
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
auto *dLdw = results->at(1);
|
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
|
|
delete results;
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, log_loss_grad_test12) {
|
|
|
|
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
|
|
NDArray dLdpExp('c', {2,3,4}, { 0. , 0. , 0. , 0. ,-0.75 ,-0.789473,-0.833333, -0.882353,-0.9375 ,-1. ,-1.071428, -1.153846,
|
|
-1.25 ,-1.363636,-1.5 , -1.666666,-1.875 ,-2.142857,-2.499999, -2.999999,-3.749997,-4.999997,-7.499993,-14.999956});
|
|
NDArray dLdwExp('c', {2,3,4}, {0.16094, 0.2484 , 0.30526, 0.34036, 0.35773, 0.35953, 0.34699, 0.32079, 0.28123, 0.22827, 0.16163, 0.08072,
|
|
-0.01533,-0.12776,-0.25828,-0.40921,-0.58373,-0.78637,-1.02369,-1.30582,-1.64951,-2.08566,-2.68241,-3.65272});
|
|
NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0.03466, 0.02882, 0.02361, 0.01884, 0.01438, 0.01014, 0.00603, 0.002 ,
|
|
-0.002 ,-0.00603,-0.01014,-0.01438,-0.01884,-0.02361,-0.02882,-0.03466,-0.04146,-0.04981,-0.06106,-0.07945});
|
|
|
|
|
|
predictions.linspace(0.04, 0.04);
|
|
labels.linspace(1);
|
|
weights.assign(0.5);
|
|
weights.t<double>(0) = 0.;
|
|
weights.t<double>(1) = 0.;
|
|
weights.t<double>(2) = 0.;
|
|
weights.t<double>(3) = 0.;
|
|
|
|
|
|
nd4j::ops::log_loss_grad op;
|
|
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
auto *dLdp = results->at(0);
|
|
auto *dLdw = results->at(1);
|
|
auto *dLdl = results->at(2);
|
|
|
|
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
|
|
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
|
|
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
|
|
|
|
delete results;
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, log_loss_grad_test13) {
|
|
|
|
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
|
NDArray weights('c', {2,3,1}, nd4j::DataType::DOUBLE);
|
|
|
|
NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
|
|
-2.08333,-2.27273, -2.5 , -2.77778,-3.125 ,-3.57143, -4.16667, -5. ,-6.25 ,-8.33333,-12.49999,-24.99993});
|
|
NDArray dLdwExp('c', {2,3,1}, {1.75828, 2.30839, 1.25309, -1.35098, -6.16602,-16.78383});
|
|
NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
|
|
-0.00334,-0.01005,-0.01689,-0.02397,-0.03141,-0.03935,-0.04803,-0.05776,-0.06909,-0.08302,-0.10176,-0.13242});
|
|
|
|
predictions.linspace(0.04, 0.04);
|
|
labels.linspace(1);
|
|
weights.assign(0.5);
|
|
weights.t<double>(0) = 0.;
|
|
weights.t<double>(1) = 0.;
|
|
weights.t<double>(2) = 0.;
|
|
|
|
nd4j::ops::log_loss_grad op;
|
|
auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
auto *dLdp = results->at(0);
|
|
auto *dLdw = results->at(1);
|
|
auto *dLdl = results->at(2);
|
|
|
|
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
|
|
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
|
|
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
|
|
|
|
delete results;
|
|
}
|
|
|
|
// resize_bicubic: upsample a 1x7x7x1 image to 30x30 (target size passed as an
// int array) and compare against precomputed reference values.
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test1) {

    NDArray input = NDArrayFactory::create<double>('c', {1, 7, 7, 1}, {
        1, 2.1, 3.15, 4.2, 5.15, 6.1, 7,
        8, 9.1, 10., 11, 12.9, 13.1, 14,
        15, 16., 17., 18, 19, 20., 21,
        22, 23., 24., 25, 26, 27, 28,
        30, 31, 32, 33, 34., 35, 36,
        37, 38, 39, 40, 41., 42, 43,
        44, 45, 46, 47, 48., 49, 50
    });
    NDArray expected = NDArrayFactory::create<double>('c', {1, 30, 30, 1}, {
        1. ,1.1976162 ,1.4174359 ,1.6775769 ,1.9961575 ,2.3283265 ,
        2.550918 ,2.7360606 ,2.9655411 ,3.2929654 ,3.5441515 ,3.7380352 ,
        3.948995 ,4.248106 ,4.5073795 ,4.6843743 ,4.8572845 ,5.104302 ,
        5.3869915 ,5.581401 ,5.7539616 ,5.974285 ,6.272836 ,6.5204263 ,
        6.718899 ,6.8871036 ,7.039068 ,7.099216 ,7.0784245 ,7.0281887 ,
        2.247592 ,2.446947 ,2.6694887 ,2.9312382 ,3.248216 ,3.5745337 ,
        3.78931 ,3.9656973 ,4.186417 ,4.5046535 ,4.740569 ,4.9217057 ,
        5.133866 ,5.459533 ,5.7744613 ,6.0197873 ,6.254011 ,6.535633 ,
        6.8097296 ,6.9607787 ,7.0749416 ,7.241601 ,7.5094895 ,7.7499495 ,
        7.954571 ,8.131972 ,8.286526 ,8.346463 ,8.325745 ,8.275683 ,
        3.6286845 ,3.830573 ,4.0569587 ,4.3211575 ,4.6364856 ,4.9556503 ,
        5.160583 ,5.3258467 ,5.535462 ,5.84216 ,6.058749 ,6.223753 ,
        6.437597 ,6.797369 ,7.1836042 ,7.5164022 ,7.8290343 ,8.154773 ,
        8.417635 ,8.512958 ,8.5521 ,8.649708 ,8.87788 ,9.108794 ,
        9.320926 ,9.509781 ,9.667375 ,9.72694 ,9.706349 ,9.656599 ,
        5.276778 ,5.480438 ,5.709702 ,5.9754477 ,6.288551 ,6.6005697 ,
        6.796207 ,6.9511423 ,7.1503997 ,7.4461427 ,7.644651 ,7.794562 ,
        8.009684 ,8.400473 ,8.851847 ,9.26469 ,9.649218, 10.015648 ,
        10.268647 ,10.313368 ,10.2843275 ,10.319379 ,10.512033 ,10.734956 ,
        10.954604 ,11.154507 ,11.315369 ,11.374779 ,11.354242 ,11.304622 ,
        7.325373 ,7.5284843 ,7.757575 ,8.022221 ,8.331997 ,8.638187 ,
        8.827649 ,8.976217 ,9.168955 ,9.45726 ,9.6442375 ,9.784517 ,
        9.999621, 10.407702 ,10.896234, 11.355122, 11.781423, 12.172186 ,
        12.420712 ,12.4374485 ,12.370511 ,12.371386 ,12.545973 ,12.766424 ,
        12.992249 ,13.20012 ,13.364252 ,13.424109 ,13.40342 ,13.353425 ,
        9.493208 ,9.692467 ,9.9169445, 10.176801, 10.482199, 10.78547 ,
        10.974367 ,11.123442 ,11.31637 ,11.603645 ,11.790616 ,11.930889 ,
        12.144082 ,12.546447 ,13.024898 ,13.4723 ,13.889232 ,14.276275 ,
        14.528972 ,14.555555 ,14.50145 ,14.515459 ,14.700572 ,14.927055 ,
        15.156046 ,15.366046 ,15.532901 ,15.594008 ,15.5728855 ,15.521847 ,
        10.970133 ,11.163599 ,11.380694 ,11.633735 ,11.935032 ,12.238887 ,
        12.43254 ,12.588294 ,12.787534 ,13.079956 ,13.27752 ,13.426631 ,
        13.636713 ,14.013844 ,14.441672 ,14.827978 ,15.191209 ,15.549808 ,
        15.81343 ,15.881828 ,15.883522 ,15.950411 ,16.16933 ,16.40794 ,
        16.636436 ,16.842583 ,17.010887 ,17.07363 ,17.05194 ,16.999537 ,
        12.219155 ,12.406129 ,12.614796 ,12.860335 ,13.157928 ,13.464224 ,
        13.665207 ,13.830567 ,14.039036 ,14.339629 ,14.552863 ,14.715049 ,
        14.921564 ,15.264454 ,15.622843 ,15.924977 ,16.213829 ,16.532364 ,
        16.8099 ,16.934835 ,17.012146 ,17.150164 ,17.413412 ,17.666712 ,
        17.892765 ,18.09207 ,18.261044 ,18.325531 ,18.303238 ,18.249378 ,
        13.7663965 ,13.947391 ,14.148263 ,14.386917 ,14.681246 ,14.990087 ,
        15.198166 ,15.372728 ,15.590062 ,15.898583 ,16.126892 ,16.301655 ,
        16.50487 ,16.815214 ,17.107498 ,17.329458 ,17.547403 ,17.827654 ,
        18.118288 ,18.296928 ,18.4461 ,18.651634 ,18.956806 ,19.22382 ,
        19.447308 ,19.639887 ,19.809319 ,19.875397 ,19.852556 ,19.797365 ,
        15.9419365 ,16.118704 ,16.314133 ,16.547867 ,16.839561 ,17.14954 ,
        17.361883 ,17.542162 ,17.764957 ,18.078188 ,18.315733 ,18.498205 ,
        18.699116 ,18.988684 ,19.238989 ,19.410137 ,19.583265 ,19.839512 ,
        20.13878 ,20.35177 ,20.546844 ,20.795671 ,21.128067 ,21.404358 ,
        21.626736 ,21.8155 ,21.98561 ,22.052843 ,22.029604 ,21.973448 ,
        17.53522 ,17.71077 ,17.904636 ,18.13695 ,18.42784 ,18.738056 ,
        18.951529 ,19.133352 ,19.357613 ,19.672083 ,19.912102 ,20.096638 ,
        20.296894 ,20.580765 ,20.819603 ,20.976887 ,21.137802 ,21.387535 ,
        21.689209 ,21.911621 ,22.119276 ,22.37999 ,22.71991 ,22.998823 ,
        23.22097 ,23.40876 ,23.57911 ,23.646685 ,23.623325 ,23.566887 ,
        18.746353 ,18.922657 ,19.117487 ,19.350685 ,19.64207 ,19.952137 ,
        20.164913 ,20.345781 ,20.569134 ,20.88284 ,21.12133 ,21.30459 ,
        21.505253 ,21.792645 ,22.038572 ,22.204426 ,22.37289 ,22.626648 ,
        22.926834 ,23.143423 ,23.343302 ,23.596668 ,23.931936 ,24.209232 ,
        24.431519 ,24.619913 ,24.79011 ,24.857473 ,24.83419 ,24.777927 ,
        20.16656 ,20.344206 ,20.540766 ,20.775532 ,21.067804 ,21.377607 ,
        21.589132 ,21.768297 ,21.99003 ,22.302366 ,22.538124 ,22.719105 ,
        22.920494 ,23.214176 ,23.472767 ,23.653934 ,23.83589 ,24.096842 ,
        24.394371 ,24.600555 ,24.786541 ,25.026773 ,25.353731 ,25.62813 ,
        25.850672 ,26.04014 ,26.210072 ,26.277063 ,26.253906 ,26.197956 ,
        22.363024 ,22.54125 ,22.738552 ,22.973991 ,23.266647 ,23.57634 ,
        23.787327 ,23.96576 ,24.186796 ,24.498543 ,24.733124 ,24.913122 ,
        25.114826 ,25.411213 ,25.675262 ,25.863028 ,26.050789 ,26.314838 ,
        26.611223 ,26.812925 ,26.992926 ,27.227505 ,27.550882 ,27.824034 ,
        28.046684 ,28.236614 ,28.406433 ,28.473265 ,28.450163 ,28.394344 ,
        24.429443 ,24.60767 ,24.80497 ,25.04041 ,25.333065 ,25.642756 ,
        25.853743 ,26.032173 ,26.25321 ,26.564959 ,26.79954 ,26.97954 ,
        27.181242 ,27.47763 ,27.74168 ,27.929441 ,28.117207 ,28.381254 ,
        28.677637 ,28.879343 ,29.059345 ,29.293922 ,29.617298 ,29.890451 ,
        30.113104 ,30.303034 ,30.472853 ,30.539684 ,30.516582 ,30.460762 ,
        26. ,26.178228 ,26.375526 ,26.61097 ,26.903624 ,27.213314 ,
        27.424305 ,27.602734 ,27.823772 ,28.135519 ,28.3701 ,28.550098 ,
        28.7518 ,29.04819 ,29.312237 ,29.5 ,29.687763 ,29.951813 ,
        30.2482 ,30.449903 ,30.629902 ,30.864483 ,31.187859 ,31.461012 ,
        31.683659 ,31.873592 ,32.043407 ,32.11024 ,32.087135 ,32.03132 ,
        27.570559 ,27.748787 ,27.946087 ,28.181528 ,28.474184 ,28.783876 ,
        28.994865 ,29.173294 ,29.39433 ,29.70608 ,29.940659 ,30.120655 ,
        30.32236 ,30.618746 ,30.882797 ,31.070557 ,31.25832 ,31.522371 ,
        31.818754 ,32.02046 ,32.20046 ,32.43504 ,32.758415 ,33.031567 ,
        33.25422 ,33.44415 ,33.613964 ,33.680794 ,33.657696 ,33.60188 ,
        29.636976 ,29.815207 ,30.0125 ,30.247944 ,30.5406 ,30.85029 ,
        31.061283 ,31.239712 ,31.46075 ,31.7725 ,32.00708 ,32.187077 ,
        32.38878 ,32.685165 ,32.949215 ,33.13698 ,33.32474 ,33.58879 ,
        33.885178 ,34.086884 ,34.26688 ,34.501457 ,34.824837 ,35.09799 ,
        35.320637 ,35.510574 ,35.68039 ,35.747215 ,35.724117 ,35.6683 ,
        31.83344 ,32.011665 ,32.20897 ,32.444412 ,32.73707 ,33.046757 ,
        33.257744 ,33.436176 ,33.657207 ,33.96896 ,34.203537 ,34.383537 ,
        34.58524 ,34.88163 ,35.145676 ,35.33344 ,35.521206 ,35.785255 ,
        36.081642 ,36.28334 ,36.46334 ,36.69792 ,37.021297 ,37.294453 ,
        37.517097 ,37.707027 ,37.876846 ,37.94368 ,37.920578 ,37.864758 ,
        33.253647 ,33.431873 ,33.62917 ,33.864613 ,34.15727 ,34.466957 ,
        34.677948 ,34.856377 ,35.077415 ,35.38916 ,35.623745 ,35.803745 ,
        36.005447 ,36.301834 ,36.565884 ,36.753647 ,36.941406 ,37.205456 ,
        37.50184 ,37.703545 ,37.883545 ,38.118122 ,38.4415 ,38.714653 ,
        38.9373 ,39.127235 ,39.297054 ,39.363884 ,39.340782 ,39.28496 ,
        34.464783 ,34.64301 ,34.840305 ,35.075752 ,35.368404 ,35.6781 ,
        35.889088 ,36.067516 ,36.28855 ,36.6003 ,36.834885 ,37.014877 ,
        37.216583 ,37.51297 ,37.77702 ,37.964783 ,38.152546 ,38.416595 ,
        38.71298 ,38.914684 ,39.094685 ,39.32926 ,39.652645 ,39.925793 ,
        40.14844 ,40.338375 ,40.508194 ,40.575024 ,40.55192 ,40.496105 ,
        36.058067 ,36.23629 ,36.43359 ,36.669033 ,36.961685 ,37.271378 ,
        37.48237 ,37.6608 ,37.881836 ,38.19359 ,38.42817 ,38.608162 ,
        38.809868 ,39.10625 ,39.3703 ,39.558064 ,39.74583 ,40.00988 ,
        40.306267 ,40.50797 ,40.68797 ,40.92255 ,41.245926 ,41.519077 ,
        41.741722 ,41.931652 ,42.101475 ,42.168304 ,42.145203 ,42.089386 ,
        38.315002 ,38.493233 ,38.690533 ,38.925976 ,39.218628 ,39.52832 ,
        39.739307 ,39.917736 ,40.138775 ,40.45052 ,40.685104 ,40.865097 ,
        41.066803 ,41.36319 ,41.627243 ,41.815002 ,42.002766 ,42.26682 ,
        42.5632 ,42.764908 ,42.944904 ,43.179485 ,43.50286 ,43.776016 ,
        43.998665 ,44.188595 ,44.358418 ,44.425247 ,44.402145 ,44.34633 ,
        40.22708 ,40.40531 ,40.602608 ,40.83805 ,41.130707 ,41.440395 ,
        41.651382 ,41.82982 ,42.050854 ,42.3626 ,42.597183 ,42.77718 ,
        42.97888 ,43.27527 ,43.53932 ,43.72708 ,43.914845 ,44.178894 ,
        44.47528 ,44.676983 ,44.856983 ,45.09156 ,45.41494 ,45.68809 ,
        45.91074 ,46.100674 ,46.270493 ,46.337322 ,46.31422 ,46.2584 ,
        41.785618 ,41.963844 ,42.161144 ,42.396584 ,42.68924 ,42.998936 ,
        43.209923 ,43.388355 ,43.609394 ,43.921143 ,44.15572 ,44.335716 ,
        44.53742 ,44.833805 ,45.09786 ,45.285614 ,45.473377 ,45.737427 ,
        46.033817 ,46.235523 ,46.415524 ,46.650105 ,46.973476 ,47.24663 ,
        47.469276 ,47.65921 ,47.82903 ,47.895855 ,47.872753 ,47.81694 ,
        43.11514 ,43.293365 ,43.490665 ,43.726105 ,44.018764 ,44.328457 ,
        44.539444 ,44.717873 ,44.93891 ,45.25066 ,45.48524 ,45.665237 ,
        45.86694 ,46.163326 ,46.427376 ,46.615143 ,46.802902 ,47.066956 ,
        47.363342 ,47.56505 ,47.74505 ,47.979626 ,48.302998 ,48.576153 ,
        48.798798 ,48.98873 ,49.158546 ,49.225376 ,49.202282 ,49.146458 ,
        44.303867 ,44.482094 ,44.679394 ,44.914833 ,45.207493 ,45.51718 ,
        45.72817 ,45.9066 ,46.12764 ,46.439384 ,46.673965 ,46.853966 ,
        47.055668 ,47.352055 ,47.6161 ,47.803867 ,47.99163 ,48.25568 ,
        48.552063 ,48.75377 ,48.933773 ,49.16835 ,49.491726 ,49.764877 ,
        49.987526 ,50.17746 ,50.347275 ,50.4141 ,50.391006 ,50.335186 ,
        44.771675 ,44.949905 ,45.1472 ,45.382645 ,45.6753 ,45.98499 ,
        46.195976 ,46.374413 ,46.595448 ,46.907196 ,47.141773 ,47.321774 ,
        47.523476 ,47.819862 ,48.08391 ,48.27168 ,48.459446 ,48.72349 ,
        49.019882 ,49.22158 ,49.401585 ,49.63616 ,49.959538 ,50.232693 ,
        50.455338 ,50.64527 ,50.81509 ,50.88192 ,50.858818 ,50.803 ,
        44.609966 ,44.788193 ,44.985493 ,45.220936 ,45.51359 ,45.82328 ,
        46.03427 ,46.2127 ,46.433743 ,46.74549 ,46.98007 ,47.160065 ,
        47.36177 ,47.658157 ,47.922207 ,48.10997 ,48.297733 ,48.561783 ,
        48.858166 ,49.059875 ,49.239872 ,49.47445 ,49.79783 ,50.07098 ,
        50.293625 ,50.48356 ,50.653378 ,50.720203 ,50.6971 ,50.64128 ,
        44.219246 ,44.397472 ,44.594772 ,44.83021 ,45.122868 ,45.43256 ,
        45.643543 ,45.82198 ,46.04302 ,46.354763 ,46.589344 ,46.76934 ,
        46.971046 ,47.267433 ,47.531483 ,47.719242 ,47.907005 ,48.17105 ,
        48.467438 ,48.66914 ,48.849144 ,49.08372 ,49.4071 ,49.680256 ,
        49.902905 ,50.092834 ,50.262653 ,50.329483 ,50.30638 ,50.25057});

    auto size = NDArrayFactory::create<int>({30, 30});
    nd4j::ops::resize_bicubic op;
    // unique_ptr frees the ResultSet even when a fatal ASSERT_* below returns early.
    std::unique_ptr<ResultSet> results(op.execute({&input, &size}, {}, {}));

    ASSERT_EQ(ND4J_STATUS_OK, results->status());
    NDArray* result = results->at(0);

    ASSERT_TRUE(expected.isSameShape(result));
    ASSERT_TRUE(expected.equalsTo(result));
}

// resize_bicubic on a batch of two 5x4 images with 3 channels filled by
// linspace(1), upsampled to 10x8 and compared with precomputed reference values.
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test2) {

    NDArray input = NDArrayFactory::create<double>('c', {2, 5, 4, 3});
    NDArray expected = NDArrayFactory::create<double>('c', {2, 10, 8, 3}, {
        1. , 2. ,3. ,2.21875, 3.21875, 4.21875, 4. , 5. , 6. ,5.5,
        6.5, 7.5, 7., 8., 9. ,8.78125, 9.78125, 10.78125, 10., 11. ,
        12., 10.28125 , 11.28125 ,12.28125, 5.875, 6.875, 7.875, 7.09375, 8.09375 ,9.09375,
        8.875, 9.875, 10.875, 10.375, 11.375, 12.375 ,11.875 ,12.875 , 13.875, 13.65625,
        14.65625, 15.65625, 14.875 ,15.875 ,16.875 , 15.15625, 16.15625, 17.15625, 13., 14.,
        15. ,14.21875, 15.21875, 16.21875, 16., 17., 18. ,17.5 ,18.5 , 19.5,
        19., 20., 21., 20.78125 ,21.78125 ,22.78125, 22., 23. , 24. , 22.28125,
        23.28125 ,24.28125 ,19. , 20., 21., 20.21875, 21.21875, 22.21875 ,22. ,23.,
        24. , 23.5, 24.5, 25.5, 25. ,26. ,27., 26.78125 , 27.78125, 28.78125,
        28., 29. ,30. ,28.28125, 29.28125, 30.28125, 25., 26., 27. ,26.21875,
        27.21875, 28.21875, 28., 29., 30., 29.5 ,30.5 ,31.5 , 31., 32.,
        33., 32.78125, 33.78125 ,34.78125 ,34., 35., 36., 34.28125, 35.28125, 36.28125,
        31. ,32., 33. , 32.21875, 33.21875, 34.21875, 34. ,35. ,36., 35.5,
        36.5 , 37.5 , 37., 38. ,39. ,38.78125, 39.78125, 40.78125, 40., 41.,
        42. ,40.28125 ,41.28125, 42.28125, 37. , 38., 39., 38.21875 ,39.21875 ,40.21875,
        40. , 41. , 42. , 41.5, 42.5, 43.5 ,43., 44., 45., 44.78125,
        45.78125, 46.78125 ,46. ,47. , 48. , 46.28125 , 47.28125, 48.28125, 44.125 ,45.125,
        46.125, 45.34375, 46.34375, 47.34375, 47.125, 48.125 ,49.125 ,48.625, 49.625 , 50.625,
        50.125 , 51.125, 52.125 ,51.90625 ,52.90625, 53.90625, 53.125, 54.125, 55.125, 53.40625,
        54.40625 ,55.40625, 49. ,50. , 51. ,50.21875, 51.21875 ,52.21875 ,52. ,53.,
        54. ,53.5 , 54.5, 55.5 ,55. ,56. ,57. ,56.78125 ,57.78125, 58.78125,
        58. ,59. ,60. ,58.28125 ,59.28125 ,60.28125, 50.125, 51.125 ,52.125 ,51.34375,
        52.34375 ,53.34375 ,53.125, 54.125, 55.125 ,54.625 ,55.625 ,56.625 ,56.125 ,57.125,
        58.125, 57.90625 ,58.90625 ,59.90625 ,59.125 ,60.125 ,61.125, 59.40625, 60.40625 ,61.40625,
        61. ,62. ,63. ,62.21875, 63.21875, 64.21875 ,64. ,65. ,66. ,65.5 ,
        66.5, 67.5, 67. ,68. ,69. ,68.78125 ,69.78125 ,70.78125 ,70., 71. ,
        72. ,70.28125 ,71.28125 ,72.28125 ,65.875 ,66.875, 67.875 ,67.09375 ,68.09375 ,69.09375,
        68.875 ,69.875 ,70.875, 70.375 ,71.375 ,72.375 ,71.875 ,72.875 ,73.875 ,73.65625,
        74.65625 ,75.65625 ,74.875 ,75.875, 76.875 ,75.15625 ,76.15625,
        77.15625 ,73. ,74. ,75., 74.21875 ,75.21875 ,76.21875,
        76. ,77. ,78. ,77.5 ,78.5 ,79.5 ,79.,
        80. ,81. ,80.78125 ,81.78125, 82.78125 ,82. ,83.,
        84. ,82.28125 ,83.28125 ,84.28125, 79. ,80. ,81.,
        80.21875 ,81.21875 ,82.21875 ,82., 83. ,84. ,83.5,
        84.5 ,85.5 ,85. ,86., 87. ,86.78125 ,87.78125,
        88.78125 ,88. ,89. ,90., 88.28125 ,89.28125 ,90.28125,
        85. ,86. ,87. ,86.21875, 87.21875 ,88.21875 ,88.,
        89. ,90. ,89.5 ,90.5, 91.5 ,91. ,92.,
        93. ,92.78125 ,93.78125 ,94.78125, 94. ,95. ,96.,
        94.28125 ,95.28125 ,96.28125 ,91., 92. ,93. ,92.21875,
        93.21875 ,94.21875 ,94. ,95., 96. ,95.5 ,96.5,
        97.5 ,97. ,98. ,99., 98.78125 ,99.78125 ,100.78125,
        100. ,101. ,102. ,100.28125, 101.28125 ,102.28125, 97.,
        98. ,99. ,98.21875 ,99.21875, 100.21875 ,100., 101.,
        102. ,101.5 ,102.5 ,103.5, 103. ,104., 105.,
        104.78125 ,105.78125 ,106.78125 ,106., 107. ,108., 106.28125,
        107.28125 ,108.28125 ,104.125 ,105.125, 106.125 ,105.34375, 106.34375,
        107.34375 ,107.125 ,108.125 ,109.125, 108.625 ,109.625, 110.625,
        110.125 ,111.125 ,112.125 ,111.90625, 112.90625 ,113.90625, 113.125,
        114.125 ,115.125 ,113.40625 ,114.40625, 115.40625 ,109., 110.,
        111. ,110.21875 ,111.21875 ,112.21875, 112., 113., 114.,
        113.5 ,114.5 ,115.5 ,115., 116., 117., 116.78125,
        117.78125 ,118.78125 ,118. ,119., 120., 118.28125, 119.28125,
        120.28125 ,110.125 ,111.125 ,112.125, 111.34375, 112.34375, 113.34375,
        113.125 ,114.125 ,115.125 ,114.625, 115.625, 116.625, 116.125,
        117.125 ,118.125 ,117.90625, 118.90625, 119.90625, 119.125, 120.125,
        121.125 ,119.40625 ,120.40625, 121.40625});

    input.linspace(1);
    auto size = NDArrayFactory::create<int>({10, 8});
    nd4j::ops::resize_bicubic op;
    // unique_ptr frees the ResultSet even when a fatal ASSERT_* below returns early.
    std::unique_ptr<ResultSet> results(op.execute({&input, &size}, {}, {}));

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    NDArray* result = results->at(0);

    ASSERT_TRUE(expected.isSameShape(result));
    ASSERT_TRUE(expected.equalsTo(result));
}

TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test3) {
    // Upscale a 3x3 image with 4 channels (filled via linspace(1)) to 6x6 with resize_bicubic.
    NDArray input = NDArrayFactory::create<double>('c', {1, 3, 3, 4});
    input.linspace(1);

    // Reference output in 'c' order: one line per output pixel (4 channels), grouped by output row.
    NDArray expected = NDArrayFactory::create<double>('c', {1, 6, 6, 4}, {
        // output row 0
        1.     , 2.     , 3.     , 4.     ,
        2.625  , 3.625  , 4.625  , 5.625  ,
        5.     , 6.     , 7.     , 8.     ,
        7.375  , 8.375  , 9.375  , 10.375 ,
        9.     , 10.    , 11.    , 12.    ,
        9.375  , 10.375 , 11.375 , 12.375 ,
        // output row 1
        5.875  , 6.875  , 7.875  , 8.875  ,
        7.5    , 8.5    , 9.5    , 10.5   ,
        9.875  , 10.875 , 11.875 , 12.875 ,
        12.25  , 13.25  , 14.25  , 15.25  ,
        13.875 , 14.875 , 15.875 , 16.875 ,
        14.25  , 15.25  , 16.25  , 17.25  ,
        // output row 2
        13.    , 14.    , 15.    , 16.    ,
        14.625 , 15.625 , 16.625 , 17.625 ,
        17.    , 18.    , 19.    , 20.    ,
        19.375 , 20.375 , 21.375 , 22.375 ,
        21.    , 22.    , 23.    , 24.    ,
        21.375 , 22.375 , 23.375 , 24.375 ,
        // output row 3
        20.125 , 21.125 , 22.125 , 23.125 ,
        21.75  , 22.75  , 23.75  , 24.75  ,
        24.125 , 25.125 , 26.125 , 27.125 ,
        26.5   , 27.5   , 28.5   , 29.5   ,
        28.125 , 29.125 , 30.125 , 31.125 ,
        28.5   , 29.5   , 30.5   , 31.5   ,
        // output row 4
        25.    , 26.    , 27.    , 28.    ,
        26.625 , 27.625 , 28.625 , 29.625 ,
        29.    , 30.    , 31.    , 32.    ,
        31.375 , 32.375 , 33.375 , 34.375 ,
        33.    , 34.    , 35.    , 36.    ,
        33.375 , 34.375 , 35.375 , 36.375 ,
        // output row 5
        26.125 , 27.125 , 28.125 , 29.125 ,
        27.75  , 28.75  , 29.75  , 30.75  ,
        30.125 , 31.125 , 32.125 , 33.125 ,
        32.5   , 33.5   , 34.5   , 35.5   ,
        34.125 , 35.125 , 36.125 , 37.125 ,
        34.5   , 35.5   , 36.5   , 37.5
    });

    // Target spatial size {height, width}.
    auto newSize = NDArrayFactory::create<int>({6, 6});

    nd4j::ops::resize_bicubic bicubic;
    auto resultSet = bicubic.execute({&input, &newSize}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, resultSet->status());

    NDArray* resized = resultSet->at(0);
    ASSERT_TRUE(expected.isSameShape(resized));
    ASSERT_TRUE(expected.equalsTo(resized));

    delete resultSet;
}
|
|
|
|
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test4) {
    // Upscale a 3x4 image with 3 channels (filled via linspace(1)) to 6x8 with resize_bicubic.
    NDArray input = NDArrayFactory::create<double>('c', {1, 3, 4, 3});
    input.linspace(1);

    // Reference output in 'c' order: one line per output pixel (3 channels), grouped by output row.
    NDArray expected = NDArrayFactory::create<double>('c', {1, 6, 8, 3}, {
        // output row 0
        1.       , 2.       , 3.       ,
        2.21875  , 3.21875  , 4.21875  ,
        4.       , 5.       , 6.       ,
        5.5      , 6.5      , 7.5      ,
        7.       , 8.       , 9.       ,
        8.78125  , 9.78125  , 10.78125 ,
        10.      , 11.      , 12.      ,
        10.28125 , 11.28125 , 12.28125 ,
        // output row 1
        5.875    , 6.875    , 7.875    ,
        7.09375  , 8.09375  , 9.09375  ,
        8.875    , 9.875    , 10.875   ,
        10.375   , 11.375   , 12.375   ,
        11.875   , 12.875   , 13.875   ,
        13.65625 , 14.65625 , 15.65625 ,
        14.875   , 15.875   , 16.875   ,
        15.15625 , 16.15625 , 17.15625 ,
        // output row 2
        13.      , 14.      , 15.      ,
        14.21875 , 15.21875 , 16.21875 ,
        16.      , 17.      , 18.      ,
        17.5     , 18.5     , 19.5     ,
        19.      , 20.      , 21.      ,
        20.78125 , 21.78125 , 22.78125 ,
        22.      , 23.      , 24.      ,
        22.28125 , 23.28125 , 24.28125 ,
        // output row 3
        20.125   , 21.125   , 22.125   ,
        21.34375 , 22.34375 , 23.34375 ,
        23.125   , 24.125   , 25.125   ,
        24.625   , 25.625   , 26.625   ,
        26.125   , 27.125   , 28.125   ,
        27.90625 , 28.90625 , 29.90625 ,
        29.125   , 30.125   , 31.125   ,
        29.40625 , 30.40625 , 31.40625 ,
        // output row 4
        25.      , 26.      , 27.      ,
        26.21875 , 27.21875 , 28.21875 ,
        28.      , 29.      , 30.      ,
        29.5     , 30.5     , 31.5     ,
        31.      , 32.      , 33.      ,
        32.78125 , 33.78125 , 34.78125 ,
        34.      , 35.      , 36.      ,
        34.28125 , 35.28125 , 36.28125 ,
        // output row 5
        26.125   , 27.125   , 28.125   ,
        27.34375 , 28.34375 , 29.34375 ,
        29.125   , 30.125   , 31.125   ,
        30.625   , 31.625   , 32.625   ,
        32.125   , 33.125   , 34.125   ,
        33.90625 , 34.90625 , 35.90625 ,
        35.125   , 36.125   , 37.125   ,
        35.40625 , 36.40625 , 37.40625
    });

    // Target spatial size {height, width}.
    auto newSize = NDArrayFactory::create<int>({6, 8});

    nd4j::ops::resize_bicubic bicubic;
    auto resultSet = bicubic.execute({&input, &newSize}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, resultSet->status());

    NDArray* resized = resultSet->at(0);
    ASSERT_TRUE(expected.isSameShape(resized));
    ASSERT_TRUE(expected.equalsTo(resized));

    delete resultSet;
}
|
|
|
|
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test5) {
    // Upscale a 4x4 image with 3 channels (filled via linspace(1)) to 8x8 with resize_bicubic.
    NDArray input = NDArrayFactory::create<double>('c', {1, 4, 4, 3});
    input.linspace(1);

    // Reference output in 'c' order; each output row (8 pixels x 3 channels) spans three lines.
    NDArray expected = NDArrayFactory::create<double>('c', {1, 8, 8, 3}, {
        // output row 0
        1.       , 2.       , 3.       , 2.21875  , 3.21875  , 4.21875  , 4.       , 5.       ,
        6.       , 5.5      , 6.5      , 7.5      , 7.       , 8.       , 9.       , 8.78125  ,
        9.78125  , 10.78125 , 10.      , 11.      , 12.      , 10.28125 , 11.28125 , 12.28125 ,
        // output row 1
        5.875    , 6.875    , 7.875    , 7.09375  , 8.09375  , 9.09375  , 8.875    , 9.875    ,
        10.875   , 10.375   , 11.375   , 12.375   , 11.875   , 12.875   , 13.875   , 13.65625 ,
        14.65625 , 15.65625 , 14.875   , 15.875   , 16.875   , 15.15625 , 16.15625 , 17.15625 ,
        // output row 2
        13.      , 14.      , 15.      , 14.21875 , 15.21875 , 16.21875 , 16.      , 17.      ,
        18.      , 17.5     , 18.5     , 19.5     , 19.      , 20.      , 21.      , 20.78125 ,
        21.78125 , 22.78125 , 22.      , 23.      , 24.      , 22.28125 , 23.28125 , 24.28125 ,
        // output row 3
        19.      , 20.      , 21.      , 20.21875 , 21.21875 , 22.21875 , 22.      , 23.      ,
        24.      , 23.5     , 24.5     , 25.5     , 25.      , 26.      , 27.      , 26.78125 ,
        27.78125 , 28.78125 , 28.      , 29.      , 30.      , 28.28125 , 29.28125 , 30.28125 ,
        // output row 4
        25.      , 26.      , 27.      , 26.21875 , 27.21875 , 28.21875 , 28.      , 29.      ,
        30.      , 29.5     , 30.5     , 31.5     , 31.      , 32.      , 33.      , 32.78125 ,
        33.78125 , 34.78125 , 34.      , 35.      , 36.      , 34.28125 , 35.28125 , 36.28125 ,
        // output row 5
        32.125   , 33.125   , 34.125   , 33.34375 , 34.34375 , 35.34375 , 35.125   , 36.125   ,
        37.125   , 36.625   , 37.625   , 38.625   , 38.125   , 39.125   , 40.125   , 39.90625 ,
        40.90625 , 41.90625 , 41.125   , 42.125   , 43.125   , 41.40625 , 42.40625 , 43.40625 ,
        // output row 6
        37.      , 38.      , 39.      , 38.21875 , 39.21875 , 40.21875 , 40.      , 41.      ,
        42.      , 41.5     , 42.5     , 43.5     , 43.      , 44.      , 45.      , 44.78125 ,
        45.78125 , 46.78125 , 46.      , 47.      , 48.      , 46.28125 , 47.28125 , 48.28125 ,
        // output row 7
        38.125   , 39.125   , 40.125   , 39.34375 , 40.34375 , 41.34375 , 41.125   , 42.125   ,
        43.125   , 42.625   , 43.625   , 44.625   , 44.125   , 45.125   , 46.125   , 45.90625 ,
        46.90625 , 47.90625 , 47.125   , 48.125   , 49.125   , 47.40625 , 48.40625 , 49.40625
    });

    // Target spatial size {height, width}.
    auto newSize = NDArrayFactory::create<int>({8, 8});

    nd4j::ops::resize_bicubic bicubic;
    auto resultSet = bicubic.execute({&input, &newSize}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, resultSet->status());

    NDArray* resized = resultSet->at(0);
    ASSERT_TRUE(expected.isSameShape(resized));
    ASSERT_TRUE(expected.equalsTo(resized));

    delete resultSet;
}
|
|
|
|
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) {
    // Upscale a rank-3 7x7 single-channel image (no batch dimension) to 30x30 with resize_bicubic.
    NDArray input = NDArrayFactory::create<double>('c', {7, 7, 1}, {
        1,  2.1, 3.15, 4.2, 5.15, 6.1,  7,
        8,  9.1, 10.,  11,  12.9, 13.1, 14,
        15, 16., 17.,  18,  19,   20.,  21,
        22, 23., 24.,  25,  26,   27,   28,
        30, 31,  32,   33,  34.,  35,   36,
        37, 38,  39,   40,  41.,  42,   43,
        44, 45,  46,   47,  48.,  49,   50
    });

    // Reference output in 'c' order; each output row (30 values) spans five lines.
    NDArray expected = NDArrayFactory::create<double>('c', {30, 30, 1}, {
        // output row 0
        1.        , 1.1976162 , 1.4174359 , 1.6775769 , 1.9961575 , 2.3283265 ,
        2.550918  , 2.7360606 , 2.9655411 , 3.2929654 , 3.5441515 , 3.7380352 ,
        3.948995  , 4.248106  , 4.5073795 , 4.6843743 , 4.8572845 , 5.104302  ,
        5.3869915 , 5.581401  , 5.7539616 , 5.974285  , 6.272836  , 6.5204263 ,
        6.718899  , 6.8871036 , 7.039068  , 7.099216  , 7.0784245 , 7.0281887 ,
        // output row 1
        2.247592  , 2.446947  , 2.6694887 , 2.9312382 , 3.248216  , 3.5745337 ,
        3.78931   , 3.9656973 , 4.186417  , 4.5046535 , 4.740569  , 4.9217057 ,
        5.133866  , 5.459533  , 5.7744613 , 6.0197873 , 6.254011  , 6.535633  ,
        6.8097296 , 6.9607787 , 7.0749416 , 7.241601  , 7.5094895 , 7.7499495 ,
        7.954571  , 8.131972  , 8.286526  , 8.346463  , 8.325745  , 8.275683  ,
        // output row 2
        3.6286845 , 3.830573  , 4.0569587 , 4.3211575 , 4.6364856 , 4.9556503 ,
        5.160583  , 5.3258467 , 5.535462  , 5.84216   , 6.058749  , 6.223753  ,
        6.437597  , 6.797369  , 7.1836042 , 7.5164022 , 7.8290343 , 8.154773  ,
        8.417635  , 8.512958  , 8.5521    , 8.649708  , 8.87788   , 9.108794  ,
        9.320926  , 9.509781  , 9.667375  , 9.72694   , 9.706349  , 9.656599  ,
        // output row 3
        5.276778  , 5.480438  , 5.709702  , 5.9754477 , 6.288551  , 6.6005697 ,
        6.796207  , 6.9511423 , 7.1503997 , 7.4461427 , 7.644651  , 7.794562  ,
        8.009684  , 8.400473  , 8.851847  , 9.26469   , 9.649218  , 10.015648 ,
        10.268647 , 10.313368 , 10.2843275, 10.319379 , 10.512033 , 10.734956 ,
        10.954604 , 11.154507 , 11.315369 , 11.374779 , 11.354242 , 11.304622 ,
        // output row 4
        7.325373  , 7.5284843 , 7.757575  , 8.022221  , 8.331997  , 8.638187  ,
        8.827649  , 8.976217  , 9.168955  , 9.45726   , 9.6442375 , 9.784517  ,
        9.999621  , 10.407702 , 10.896234 , 11.355122 , 11.781423 , 12.172186 ,
        12.420712 , 12.4374485, 12.370511 , 12.371386 , 12.545973 , 12.766424 ,
        12.992249 , 13.20012  , 13.364252 , 13.424109 , 13.40342  , 13.353425 ,
        // output row 5
        9.493208  , 9.692467  , 9.9169445 , 10.176801 , 10.482199 , 10.78547  ,
        10.974367 , 11.123442 , 11.31637  , 11.603645 , 11.790616 , 11.930889 ,
        12.144082 , 12.546447 , 13.024898 , 13.4723   , 13.889232 , 14.276275 ,
        14.528972 , 14.555555 , 14.50145  , 14.515459 , 14.700572 , 14.927055 ,
        15.156046 , 15.366046 , 15.532901 , 15.594008 , 15.5728855, 15.521847 ,
        // output row 6
        10.970133 , 11.163599 , 11.380694 , 11.633735 , 11.935032 , 12.238887 ,
        12.43254  , 12.588294 , 12.787534 , 13.079956 , 13.27752  , 13.426631 ,
        13.636713 , 14.013844 , 14.441672 , 14.827978 , 15.191209 , 15.549808 ,
        15.81343  , 15.881828 , 15.883522 , 15.950411 , 16.16933  , 16.40794  ,
        16.636436 , 16.842583 , 17.010887 , 17.07363  , 17.05194  , 16.999537 ,
        // output row 7
        12.219155 , 12.406129 , 12.614796 , 12.860335 , 13.157928 , 13.464224 ,
        13.665207 , 13.830567 , 14.039036 , 14.339629 , 14.552863 , 14.715049 ,
        14.921564 , 15.264454 , 15.622843 , 15.924977 , 16.213829 , 16.532364 ,
        16.8099   , 16.934835 , 17.012146 , 17.150164 , 17.413412 , 17.666712 ,
        17.892765 , 18.09207  , 18.261044 , 18.325531 , 18.303238 , 18.249378 ,
        // output row 8
        13.7663965, 13.947391 , 14.148263 , 14.386917 , 14.681246 , 14.990087 ,
        15.198166 , 15.372728 , 15.590062 , 15.898583 , 16.126892 , 16.301655 ,
        16.50487  , 16.815214 , 17.107498 , 17.329458 , 17.547403 , 17.827654 ,
        18.118288 , 18.296928 , 18.4461   , 18.651634 , 18.956806 , 19.22382  ,
        19.447308 , 19.639887 , 19.809319 , 19.875397 , 19.852556 , 19.797365 ,
        // output row 9
        15.9419365, 16.118704 , 16.314133 , 16.547867 , 16.839561 , 17.14954  ,
        17.361883 , 17.542162 , 17.764957 , 18.078188 , 18.315733 , 18.498205 ,
        18.699116 , 18.988684 , 19.238989 , 19.410137 , 19.583265 , 19.839512 ,
        20.13878  , 20.35177  , 20.546844 , 20.795671 , 21.128067 , 21.404358 ,
        21.626736 , 21.8155   , 21.98561  , 22.052843 , 22.029604 , 21.973448 ,
        // output row 10
        17.53522  , 17.71077  , 17.904636 , 18.13695  , 18.42784  , 18.738056 ,
        18.951529 , 19.133352 , 19.357613 , 19.672083 , 19.912102 , 20.096638 ,
        20.296894 , 20.580765 , 20.819603 , 20.976887 , 21.137802 , 21.387535 ,
        21.689209 , 21.911621 , 22.119276 , 22.37999  , 22.71991  , 22.998823 ,
        23.22097  , 23.40876  , 23.57911  , 23.646685 , 23.623325 , 23.566887 ,
        // output row 11
        18.746353 , 18.922657 , 19.117487 , 19.350685 , 19.64207  , 19.952137 ,
        20.164913 , 20.345781 , 20.569134 , 20.88284  , 21.12133  , 21.30459  ,
        21.505253 , 21.792645 , 22.038572 , 22.204426 , 22.37289  , 22.626648 ,
        22.926834 , 23.143423 , 23.343302 , 23.596668 , 23.931936 , 24.209232 ,
        24.431519 , 24.619913 , 24.79011  , 24.857473 , 24.83419  , 24.777927 ,
        // output row 12
        20.16656  , 20.344206 , 20.540766 , 20.775532 , 21.067804 , 21.377607 ,
        21.589132 , 21.768297 , 21.99003  , 22.302366 , 22.538124 , 22.719105 ,
        22.920494 , 23.214176 , 23.472767 , 23.653934 , 23.83589  , 24.096842 ,
        24.394371 , 24.600555 , 24.786541 , 25.026773 , 25.353731 , 25.62813  ,
        25.850672 , 26.04014  , 26.210072 , 26.277063 , 26.253906 , 26.197956 ,
        // output row 13
        22.363024 , 22.54125  , 22.738552 , 22.973991 , 23.266647 , 23.57634  ,
        23.787327 , 23.96576  , 24.186796 , 24.498543 , 24.733124 , 24.913122 ,
        25.114826 , 25.411213 , 25.675262 , 25.863028 , 26.050789 , 26.314838 ,
        26.611223 , 26.812925 , 26.992926 , 27.227505 , 27.550882 , 27.824034 ,
        28.046684 , 28.236614 , 28.406433 , 28.473265 , 28.450163 , 28.394344 ,
        // output row 14
        24.429443 , 24.60767  , 24.80497  , 25.04041  , 25.333065 , 25.642756 ,
        25.853743 , 26.032173 , 26.25321  , 26.564959 , 26.79954  , 26.97954  ,
        27.181242 , 27.47763  , 27.74168  , 27.929441 , 28.117207 , 28.381254 ,
        28.677637 , 28.879343 , 29.059345 , 29.293922 , 29.617298 , 29.890451 ,
        30.113104 , 30.303034 , 30.472853 , 30.539684 , 30.516582 , 30.460762 ,
        // output row 15
        26.       , 26.178228 , 26.375526 , 26.61097  , 26.903624 , 27.213314 ,
        27.424305 , 27.602734 , 27.823772 , 28.135519 , 28.3701   , 28.550098 ,
        28.7518   , 29.04819  , 29.312237 , 29.5      , 29.687763 , 29.951813 ,
        30.2482   , 30.449903 , 30.629902 , 30.864483 , 31.187859 , 31.461012 ,
        31.683659 , 31.873592 , 32.043407 , 32.11024  , 32.087135 , 32.03132  ,
        // output row 16
        27.570559 , 27.748787 , 27.946087 , 28.181528 , 28.474184 , 28.783876 ,
        28.994865 , 29.173294 , 29.39433  , 29.70608  , 29.940659 , 30.120655 ,
        30.32236  , 30.618746 , 30.882797 , 31.070557 , 31.25832  , 31.522371 ,
        31.818754 , 32.02046  , 32.20046  , 32.43504  , 32.758415 , 33.031567 ,
        33.25422  , 33.44415  , 33.613964 , 33.680794 , 33.657696 , 33.60188  ,
        // output row 17
        29.636976 , 29.815207 , 30.0125   , 30.247944 , 30.5406   , 30.85029  ,
        31.061283 , 31.239712 , 31.46075  , 31.7725   , 32.00708  , 32.187077 ,
        32.38878  , 32.685165 , 32.949215 , 33.13698  , 33.32474  , 33.58879  ,
        33.885178 , 34.086884 , 34.26688  , 34.501457 , 34.824837 , 35.09799  ,
        35.320637 , 35.510574 , 35.68039  , 35.747215 , 35.724117 , 35.6683   ,
        // output row 18
        31.83344  , 32.011665 , 32.20897  , 32.444412 , 32.73707  , 33.046757 ,
        33.257744 , 33.436176 , 33.657207 , 33.96896  , 34.203537 , 34.383537 ,
        34.58524  , 34.88163  , 35.145676 , 35.33344  , 35.521206 , 35.785255 ,
        36.081642 , 36.28334  , 36.46334  , 36.69792  , 37.021297 , 37.294453 ,
        37.517097 , 37.707027 , 37.876846 , 37.94368  , 37.920578 , 37.864758 ,
        // output row 19
        33.253647 , 33.431873 , 33.62917  , 33.864613 , 34.15727  , 34.466957 ,
        34.677948 , 34.856377 , 35.077415 , 35.38916  , 35.623745 , 35.803745 ,
        36.005447 , 36.301834 , 36.565884 , 36.753647 , 36.941406 , 37.205456 ,
        37.50184  , 37.703545 , 37.883545 , 38.118122 , 38.4415   , 38.714653 ,
        38.9373   , 39.127235 , 39.297054 , 39.363884 , 39.340782 , 39.28496  ,
        // output row 20
        34.464783 , 34.64301  , 34.840305 , 35.075752 , 35.368404 , 35.6781   ,
        35.889088 , 36.067516 , 36.28855  , 36.6003   , 36.834885 , 37.014877 ,
        37.216583 , 37.51297  , 37.77702  , 37.964783 , 38.152546 , 38.416595 ,
        38.71298  , 38.914684 , 39.094685 , 39.32926  , 39.652645 , 39.925793 ,
        40.14844  , 40.338375 , 40.508194 , 40.575024 , 40.55192  , 40.496105 ,
        // output row 21
        36.058067 , 36.23629  , 36.43359  , 36.669033 , 36.961685 , 37.271378 ,
        37.48237  , 37.6608   , 37.881836 , 38.19359  , 38.42817  , 38.608162 ,
        38.809868 , 39.10625  , 39.3703   , 39.558064 , 39.74583  , 40.00988  ,
        40.306267 , 40.50797  , 40.68797  , 40.92255  , 41.245926 , 41.519077 ,
        41.741722 , 41.931652 , 42.101475 , 42.168304 , 42.145203 , 42.089386 ,
        // output row 22
        38.315002 , 38.493233 , 38.690533 , 38.925976 , 39.218628 , 39.52832  ,
        39.739307 , 39.917736 , 40.138775 , 40.45052  , 40.685104 , 40.865097 ,
        41.066803 , 41.36319  , 41.627243 , 41.815002 , 42.002766 , 42.26682  ,
        42.5632   , 42.764908 , 42.944904 , 43.179485 , 43.50286  , 43.776016 ,
        43.998665 , 44.188595 , 44.358418 , 44.425247 , 44.402145 , 44.34633  ,
        // output row 23
        40.22708  , 40.40531  , 40.602608 , 40.83805  , 41.130707 , 41.440395 ,
        41.651382 , 41.82982  , 42.050854 , 42.3626   , 42.597183 , 42.77718  ,
        42.97888  , 43.27527  , 43.53932  , 43.72708  , 43.914845 , 44.178894 ,
        44.47528  , 44.676983 , 44.856983 , 45.09156  , 45.41494  , 45.68809  ,
        45.91074  , 46.100674 , 46.270493 , 46.337322 , 46.31422  , 46.2584   ,
        // output row 24
        41.785618 , 41.963844 , 42.161144 , 42.396584 , 42.68924  , 42.998936 ,
        43.209923 , 43.388355 , 43.609394 , 43.921143 , 44.15572  , 44.335716 ,
        44.53742  , 44.833805 , 45.09786  , 45.285614 , 45.473377 , 45.737427 ,
        46.033817 , 46.235523 , 46.415524 , 46.650105 , 46.973476 , 47.24663  ,
        47.469276 , 47.65921  , 47.82903  , 47.895855 , 47.872753 , 47.81694  ,
        // output row 25
        43.11514  , 43.293365 , 43.490665 , 43.726105 , 44.018764 , 44.328457 ,
        44.539444 , 44.717873 , 44.93891  , 45.25066  , 45.48524  , 45.665237 ,
        45.86694  , 46.163326 , 46.427376 , 46.615143 , 46.802902 , 47.066956 ,
        47.363342 , 47.56505  , 47.74505  , 47.979626 , 48.302998 , 48.576153 ,
        48.798798 , 48.98873  , 49.158546 , 49.225376 , 49.202282 , 49.146458 ,
        // output row 26
        44.303867 , 44.482094 , 44.679394 , 44.914833 , 45.207493 , 45.51718  ,
        45.72817  , 45.9066   , 46.12764  , 46.439384 , 46.673965 , 46.853966 ,
        47.055668 , 47.352055 , 47.6161   , 47.803867 , 47.99163  , 48.25568  ,
        48.552063 , 48.75377  , 48.933773 , 49.16835  , 49.491726 , 49.764877 ,
        49.987526 , 50.17746  , 50.347275 , 50.4141   , 50.391006 , 50.335186 ,
        // output row 27
        44.771675 , 44.949905 , 45.1472   , 45.382645 , 45.6753   , 45.98499  ,
        46.195976 , 46.374413 , 46.595448 , 46.907196 , 47.141773 , 47.321774 ,
        47.523476 , 47.819862 , 48.08391  , 48.27168  , 48.459446 , 48.72349  ,
        49.019882 , 49.22158  , 49.401585 , 49.63616  , 49.959538 , 50.232693 ,
        50.455338 , 50.64527  , 50.81509  , 50.88192  , 50.858818 , 50.803    ,
        // output row 28
        44.609966 , 44.788193 , 44.985493 , 45.220936 , 45.51359  , 45.82328  ,
        46.03427  , 46.2127   , 46.433743 , 46.74549  , 46.98007  , 47.160065 ,
        47.36177  , 47.658157 , 47.922207 , 48.10997  , 48.297733 , 48.561783 ,
        48.858166 , 49.059875 , 49.239872 , 49.47445  , 49.79783  , 50.07098  ,
        50.293625 , 50.48356  , 50.653378 , 50.720203 , 50.6971   , 50.64128  ,
        // output row 29
        44.219246 , 44.397472 , 44.594772 , 44.83021  , 45.122868 , 45.43256  ,
        45.643543 , 45.82198  , 46.04302  , 46.354763 , 46.589344 , 46.76934  ,
        46.971046 , 47.267433 , 47.531483 , 47.719242 , 47.907005 , 48.17105  ,
        48.467438 , 48.66914  , 48.849144 , 49.08372  , 49.4071   , 49.680256 ,
        49.902905 , 50.092834 , 50.262653 , 50.329483 , 50.30638  , 50.25057
    });

    // Target spatial size {height, width}.
    auto newSize = NDArrayFactory::create<int>({30, 30});

    nd4j::ops::resize_bicubic bicubic;
    auto resultSet = bicubic.execute({&input, &newSize}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, resultSet->status());

    NDArray* resized = resultSet->at(0);
    ASSERT_TRUE(expected.isSameShape(resized));
    ASSERT_TRUE(expected.equalsTo(resized));

    delete resultSet;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {
    // Exercises copy-assignment and copy-construction of SummaryStatsData<double>:
    // a default-constructed value (var1) and one with every moment set to 5 (var2).
    functions::summarystats::SummaryStatsData<double> var1;
    functions::summarystats::SummaryStatsData<double> var2;
    var2.n = var2.mean = var2.M2 = var2.M3 = var2.M4 = var2.bias = 5;

    // Automatic storage instead of new[]/delete[]: ASSERT_* returns from the test body
    // on failure, which previously skipped the delete[] and leaked the array.
    functions::summarystats::SummaryStatsData<double> arr[2];
    arr[0] = var1;
    arr[1] = var2;
    arr[0] = arr[1];   // element-to-element copy-assignment must carry var2's values over

    functions::summarystats::SummaryStatsData<double> var3(var1);

    // Same conditions as before, split into individual assertions so a failure
    // reports which field disagrees.
    ASSERT_EQ(arr[0].n, arr[0].mean);
    ASSERT_EQ(arr[0].M2, arr[0].M3);
    ASSERT_EQ(5, arr[0].n);

    ASSERT_EQ(arr[1].n, arr[1].mean);
    ASSERT_EQ(arr[1].M2, arr[1].M3);
    ASSERT_EQ(5, arr[1].n);

    ASSERT_EQ(var3.n, var3.mean);
    ASSERT_EQ(var3.M2, var3.M3);
    ASSERT_EQ(0, var3.n);
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) {
    // Full-shape weights (all 0.5), integer arg 0 (presumably the reduction mode).
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    weights.assign(0.5);
    labels.linspace(1);

    NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52,
                                   -12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04});
    NDArray dLdwExp('c', {2,3,4}, {0.9216 , 3.6864 , 8.2944 , 14.7456 , 23.04 , 33.1776 , 45.1584 , 58.9824 , 74.6496 , 92.16 ,111.51361,132.7104 ,
                                   155.75038,180.63359,207.35999,235.9296 ,266.34238,298.59842,332.6976 ,368.64001,406.4256 ,446.05444,487.5264 ,530.84161});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    // Outputs: 0 = dL/dpredictions, 1 = dL/dweights, 2 = dL/dlabels.
    auto* dLdp = outputs->at(0);
    auto* dLdw = outputs->at(1);
    auto* dLdl = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    // The label gradient is checked against the negated prediction gradient.
    const auto negatedDLdl = -(*dLdl);
    ASSERT_TRUE(dLdpExp.isSameShape(negatedDLdl));
    ASSERT_TRUE(dLdpExp.equalsTo(negatedDLdl));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test2) {
    // Weights broadcast over axis 1 ({2,1,4}), integer arg 0; only dL/dweights is checked.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,1,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    weights.assign(0.5);
    labels.linspace(1);

    NDArray dLdwExp('c', {2,1,4}, {98.61121,129.024 , 164.9664 , 206.4384 , 828.51837,925.28644,1027.58398,1135.41113});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto* dLdw = outputs->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) {
    // Scalar weight (0.5), integer arg 1; all three gradients are checked.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    weights.assign(0.5);
    labels.linspace(1);

    NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52,
                                   -12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04});
    NDArray dLdwExp('c', {}, {4515.84});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto* dLdp = outputs->at(0);
    auto* dLdw = outputs->at(1);
    auto* dLdl = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    // The label gradient is checked against the negated prediction gradient.
    const auto negatedDLdl = -(*dLdl);
    ASSERT_TRUE(dLdpExp.isSameShape(negatedDLdl));
    ASSERT_TRUE(dLdpExp.equalsTo(negatedDLdl));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test4) {
    // Weights {1,3,1} broadcast over axes 0 and 2, integer arg 1; only dL/dweights is checked.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    weights.assign(0.5);
    labels.linspace(1);

    NDArray dLdwExp('c', {1,3,1}, {807.32153, 1426.63684, 2281.88159});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto* dLdw = outputs->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test5) {
    // Full-shape weights (all 0.5), integer arg 2; all three gradients are checked.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    weights.assign(0.5);
    labels.linspace(1);

    NDArray dLdpExp('c', {2,3,4}, {-0.08,-0.16,-0.24,-0.32,-0.4 ,-0.48,-0.56,-0.64,-0.72,-0.8 ,-0.88,-0.96,
                                   -1.04,-1.12,-1.2 ,-1.28,-1.36,-1.44,-1.52,-1.6 ,-1.68,-1.76,-1.84,-1.92});
    NDArray dLdwExp('c', {2,3,4}, {-15.6032,-15.3728,-14.9888,-14.4512,-13.76 ,-12.9152,-11.9168,-10.7648, -9.4592, -8. , -6.3872, -4.6208,
                                   -2.7008, -0.6272, 1.6 , 3.9808, 6.5152, 9.2032, 12.0448, 15.04 , 18.1888, 21.4912, 24.9472, 28.5568});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto* dLdp = outputs->at(0);
    auto* dLdw = outputs->at(1);
    auto* dLdl = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    // The label gradient is checked against the negated prediction gradient.
    const auto negatedDLdl = -(*dLdl);
    ASSERT_TRUE(dLdpExp.isSameShape(negatedDLdl));
    ASSERT_TRUE(dLdpExp.equalsTo(negatedDLdl));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test6) {
    // Weights {1,3,1}, integer arg 2; only dL/dweights is checked.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    weights.assign(0.5);
    labels.linspace(1);

    NDArray dLdwExp('c', {1,3,1}, {-58.16319, -6.5536 , 64.71682});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto* dLdw = outputs->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) {
    // Scalar weight, integer arg 2; the expected dL/dweights is exactly zero.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    weights.assign(0.5);
    labels.linspace(1);

    NDArray dLdwExp('c', {}, {0.});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto* dLdw = outputs->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test8) {
    // Full-shape weights of 0.5 except the first four elements, which are zeroed; integer arg 2.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    for (int i = 0; i < 4; ++i)
        weights.p(i, 0.);

    NDArray dLdpExp('c', {2,3,4}, {0. ,0. ,0. ,0. ,-0.48 ,-0.576,-0.672,-0.768,-0.864,-0.96 ,-1.056,-1.152,
                                   -1.248,-1.344,-1.44 ,-1.536,-1.632,-1.728,-1.824,-1.92 ,-2.016,-2.112,-2.208,-2.304});
    NDArray dLdwExp('c', {2,3,4}, {-22.3488 ,-22.07232,-21.61152,-20.9664 ,-20.13696,-19.1232 ,-17.92512,-16.54272,-14.976 ,-13.22496,-11.2896 , -9.16992,
                                   -6.86592, -4.3776 , -1.70496, 1.152 , 4.19328, 7.41888, 10.8288 , 14.42304, 18.2016 , 22.16449, 26.31168, 30.6432 });

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto* dLdp = outputs->at(0);
    auto* dLdw = outputs->at(1);
    auto* dLdl = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    // The label gradient is checked against the negated prediction gradient.
    const auto negatedDLdl = -(*dLdl);
    ASSERT_TRUE(dLdpExp.isSameShape(negatedDLdl));
    ASSERT_TRUE(dLdpExp.equalsTo(negatedDLdl));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test9) {
    // Full-shape weights (all 0.5), integer arg 3; all three gradients are checked.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    weights.assign(0.5);
    labels.linspace(1);

    NDArray dLdpExp('c', {2,3,4}, {-0.04,-0.08,-0.12,-0.16,-0.2 ,-0.24,-0.28,-0.32,-0.36,-0.4 ,-0.44,-0.48,
                                   -0.52,-0.56,-0.6 ,-0.64,-0.68,-0.72,-0.76,-0.8 ,-0.84,-0.88,-0.92,-0.96});
    NDArray dLdwExp('c', {2,3,4}, {0.0384, 0.1536, 0.3456, 0.6144, 0.96 , 1.3824, 1.8816, 2.4576, 3.1104, 3.84 , 4.6464, 5.5296,
                                   6.4896, 7.5264, 8.64 , 9.8304,11.0976,12.4416,13.8624,15.36 ,16.9344,18.5856,20.3136,22.1184});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto* dLdp = outputs->at(0);
    auto* dLdw = outputs->at(1);
    auto* dLdl = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    // The label gradient is checked against the negated prediction gradient.
    const auto negatedDLdl = -(*dLdl);
    ASSERT_TRUE(dLdpExp.isSameShape(negatedDLdl));
    ASSERT_TRUE(dLdpExp.equalsTo(negatedDLdl));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test10) {
    // Weights of shape {1,1}, integer arg 3; only dL/dweights is checked.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    weights.assign(0.5);
    labels.linspace(1);

    NDArray dLdwExp('c', {1,1}, {188.16});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto* dLdw = outputs->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test11) {
    // Weights {1,3,1}, integer arg 3; only dL/dweights is checked.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    weights.assign(0.5);
    labels.linspace(1);

    NDArray dLdwExp('c', {1,3,1}, {33.6384 ,59.4432 ,95.07841});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto* dLdw = outputs->at(1);
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test12) {
    // Full-shape weights of 0.5 except the first four elements, which are zeroed; integer arg 3.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    for (int i = 0; i < 4; ++i)
        weights.t<double>(i) = 0.;

    NDArray dLdpExp('c', {2,3,4}, {0.,0.,0.,0., -0.24 ,-0.288,-0.336,-0.384,-0.432,-0.48 ,-0.528,-0.576,
                                   -0.624,-0.672,-0.72 ,-0.768,-0.816,-0.864,-0.912,-0.96 ,-1.008,-1.056,-1.104,-1.152});
    NDArray dLdwExp('c', {2,3,4}, {0.04608, 0.18432, 0.41472, 0.73728, 1.152 , 1.65888, 2.25792, 2.94912, 3.73248, 4.608 , 5.57568, 6.63552,
                                   7.78752, 9.03168,10.368 ,11.79648,13.31712,14.92992,16.63488,18.432 ,20.32128,22.30272,24.37632,26.54208});

    nd4j::ops::mean_sqerr_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto* dLdp = outputs->at(0);
    auto* dLdw = outputs->at(1);
    auto* dLdl = outputs->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    // The label gradient is checked against the negated prediction gradient.
    const auto negatedDLdl = -(*dLdl);
    ASSERT_TRUE(dLdpExp.isSameShape(negatedDLdl));
    ASSERT_TRUE(dLdpExp.equalsTo(negatedDLdl));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test13) {
    // Weights of shape {2,3,1} with the first three entries zeroed; integer argument 3.
    // Checks all three gradients; labels gradient should equal -(predictions gradient).
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,1}, nd4j::DataType::DOUBLE);

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    for (int i = 0; i < 3; ++i)
        weights.t<double>(i) = 0.;

    NDArray expGradP('c', {2,3,4}, {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                    -1.04,-1.12,-1.2 ,-1.28,-1.36,-1.44,-1.52,-1.6 ,-1.68,-1.76,-1.84,-1.92});
    NDArray expGradW('c', {2,3,1}, {2.304 , 13.3632 , 34.2528 , 64.97279,105.5232 ,155.90401});

    nd4j::ops::mean_sqerr_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(expGradP.isSameShape(gradP));
    ASSERT_TRUE(expGradP.equalsTo(gradP));
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
    ASSERT_TRUE(expGradP.isSameShape(-*gradL));
    ASSERT_TRUE(expGradP.equalsTo(-*gradL));

    delete outputs;
}
|
|
|
|
TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test1) {
    // Element-wise (x - y)^2 for two vectors of identical shape.
    auto x = NDArrayFactory::create<float>('c', {4}, {0, 1, 2, 3});
    auto y = NDArrayFactory::create<float>('c',{4}, {3, 2, 1, 0});
    auto expected = NDArrayFactory::create<float>('c', {4}, {9, 1,1, 9});

    nd4j::ops::squaredsubtract sqSubOp;
    auto outputs = sqSubOp.execute({&x, &y}, {}, {});

    ASSERT_EQ(Status::OK(), outputs->status());
    ASSERT_TRUE(expected.equalsTo(outputs->at(0)));

    delete outputs;
}
|
|
|
|
TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test2) {
    // (x - y)^2 where the {4} vector y is broadcast across both rows of the {2,4} x.
    auto x = NDArrayFactory::create<float>('c', {2, 4}, {0, 1, 2, 3, 0, 1, 2, 3});
    auto y = NDArrayFactory::create<float>('c',{4}, {3, 2, 1, 0});
    auto expected = NDArrayFactory::create<float>('c', {2, 4}, {9, 1,1, 9, 9, 1, 1, 9});

    nd4j::ops::squaredsubtract sqSubOp;
    auto outputs = sqSubOp.execute({&x, &y}, {}, {});

    ASSERT_EQ(Status::OK(), outputs->status());
    ASSERT_TRUE(expected.equalsTo(outputs->at(0)));

    delete outputs;
}
|
|
|
|
TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test3) {
    // Backprop of (x - y)^2 with y broadcast over x's rows. The expected gradient
    // w.r.t. x equals 2 * (x - y) * eps element-wise. Only output #0 is checked.
    auto x = NDArrayFactory::create<float>('c', {2, 4}, {0, 1, 2, 3, 0, 1, 2, 3});
    auto y = NDArrayFactory::create<float>('c',{4}, {3, 2, 1, 0});
    auto eps = NDArrayFactory::create<float>('c', {2, 4}, {1,2,3,4,5,6,7,8});
    auto expected = NDArrayFactory::create<float>('c', {2, 4}, {-6, -4, 6, 24, -30, -12, 14, 48});

    nd4j::ops::squaredsubtract_bp sqSubBpOp;
    auto outputs = sqSubBpOp.execute({&x, &y, &eps}, {}, {});

    ASSERT_EQ(Status::OK(), outputs->status());
    ASSERT_TRUE(expected.equalsTo(outputs->at(0)));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test1) {
    // Full-shape weights filled with 0.5; integer argument 0 (presumably the reduction mode).
    // Checks all three gradients; labels gradient should equal -(predictions gradient).
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradP('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,
                                    -0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5});
    NDArray expGradW('c', {2,3,4}, {0.96, 1.92, 2.88, 3.84, 4.8 , 5.76, 6.72, 7.68, 8.64, 9.6 ,10.56,11.52,
                                    12.48,13.44,14.4 ,15.36,16.32,17.28,18.24,19.2 ,20.16,21.12,22.08,23.04});

    nd4j::ops::absolute_difference_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&predictions, &weights, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(expGradP.isSameShape(gradP));
    ASSERT_TRUE(expGradP.equalsTo(gradP));
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
    ASSERT_TRUE(expGradP.isSameShape(-*gradL));
    ASSERT_TRUE(expGradP.equalsTo(-*gradL));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test2) {
    // Weights of shape {2,1,4} broadcast over {2,3,4}; integer argument 0.
    // Only the weights gradient (output #1) is checked.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,1,4}, nd4j::DataType::DOUBLE);

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradW('c', {2,1,4}, {14.4 , 17.28, 20.16, 23.04, 48.96, 51.84, 54.72, 57.6});

    nd4j::ops::absolute_difference_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&predictions, &weights, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradW = outputs->at(1);
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) {
    // Scalar weights; integer argument 1. The expected scalar weights gradient (288)
    // equals the sum of |predictions - labels| for these inputs.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradP('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,
                                    -0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5});
    NDArray expGradW('c', {}, {288.});

    nd4j::ops::absolute_difference_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&predictions, &weights, &labels}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(expGradP.isSameShape(gradP));
    ASSERT_TRUE(expGradP.equalsTo(gradP));
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
    ASSERT_TRUE(expGradP.isSameShape(-*gradL));
    ASSERT_TRUE(expGradP.equalsTo(-*gradL));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test4) {
    // Weights of shape {1,3,1}; integer argument 1. Only the weights gradient is checked.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradW('c', {1,3,1}, {65.28, 96., 126.72001});

    nd4j::ops::absolute_difference_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&predictions, &weights, &labels}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradW = outputs->at(1);
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test5) {
    // Full-shape weights; integer argument 2. Checks all three gradients;
    // labels gradient should equal -(predictions gradient).
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradP('c', {2,3,4}, {-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,
                                    -0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167});
    NDArray expGradW('c', {2,3,4}, {-0.92,-0.84,-0.76,-0.68,-0.6 ,-0.52,-0.44,-0.36,-0.28,-0.2 ,-0.12,-0.04,
                                    0.04, 0.12, 0.2 , 0.28, 0.36, 0.44, 0.52, 0.6 , 0.68, 0.76, 0.84, 0.92});

    nd4j::ops::absolute_difference_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(expGradP.isSameShape(gradP));
    ASSERT_TRUE(expGradP.equalsTo(gradP));
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
    ASSERT_TRUE(expGradP.isSameShape(-*gradL));
    ASSERT_TRUE(expGradP.equalsTo(-*gradL));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test6) {
    // Weights of shape {1,3,1}; integer argument 2. Only the weights gradient is checked.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradW('c', {1,3,1}, {-2.56, 0., 2.56});

    nd4j::ops::absolute_difference_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradW = outputs->at(1);
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) {
    // Scalar weights; integer argument 2. The weights gradient is expected to be 0.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradW('c', {}, {0.});

    nd4j::ops::absolute_difference_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradW = outputs->at(1);
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test8) {
    // Full-shape weights with the first four entries zeroed; integer argument 2.
    // Checks all three gradients; labels gradient should equal -(predictions gradient).
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    for (int i = 0; i < 4; ++i)
        weights.p(i, 0.);

    NDArray expGradP('c', {2,3,4}, {-0. ,-0. ,-0. ,-0. ,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,
                                    -0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05});
    NDArray expGradW('c', {2,3,4}, {-1.296,-1.2 ,-1.104,-1.008,-0.912,-0.816,-0.72 ,-0.624,-0.528,-0.432,-0.336,-0.24 ,
                                    -0.144,-0.048, 0.048, 0.144, 0.24 , 0.336, 0.432, 0.528, 0.624, 0.72 , 0.816, 0.912});

    nd4j::ops::absolute_difference_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(expGradP.isSameShape(gradP));
    ASSERT_TRUE(expGradP.equalsTo(gradP));
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
    ASSERT_TRUE(expGradP.isSameShape(-*gradL));
    ASSERT_TRUE(expGradP.equalsTo(-*gradL));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test9) {
    // Full-shape weights; integer argument 3. Checks all three gradients;
    // labels gradient should equal -(predictions gradient).
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradP('c', {2,3,4}, {-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,
                                    -0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083});
    NDArray expGradW('c', {2,3,4}, {0.04, 0.08, 0.12, 0.16, 0.2 , 0.24, 0.28, 0.32,0.36, 0.4 , 0.44, 0.48,
                                    0.52, 0.56, 0.6 , 0.64,0.68, 0.72, 0.76, 0.8 ,0.84, 0.88, 0.92, 0.96});

    nd4j::ops::absolute_difference_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(expGradP.isSameShape(gradP));
    ASSERT_TRUE(expGradP.equalsTo(gradP));
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
    ASSERT_TRUE(expGradP.isSameShape(-*gradL));
    ASSERT_TRUE(expGradP.equalsTo(-*gradL));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test10) {
    // Weights of shape {1,1}; integer argument 3. Only the weights gradient is checked.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradW('c', {1,1}, {12.});

    nd4j::ops::absolute_difference_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradW = outputs->at(1);
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test11) {
    // Weights of shape {1,3,1}; integer argument 3. Only the weights gradient is checked.
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradW('c', {1,3,1}, {2.72, 4., 5.28});

    nd4j::ops::absolute_difference_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradW = outputs->at(1);
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test12) {
    // Full-shape weights with the first four entries zeroed; integer argument 3.
    // Checks all three gradients; labels gradient should equal -(predictions gradient).
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    for (int i = 0; i < 4; ++i)
        weights.t<double>(i) = 0.;

    NDArray expGradP('c', {2,3,4}, {0., 0., 0., 0., -0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025,
                                    -0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025});
    NDArray expGradW('c', {2,3,4}, {0.048, 0.096, 0.144, 0.192,0.24 , 0.288, 0.336, 0.384,0.432, 0.48 , 0.528, 0.576,
                                    0.624, 0.672, 0.72 , 0.768,0.816, 0.864, 0.912, 0.96 ,1.008, 1.056, 1.104, 1.152});

    nd4j::ops::absolute_difference_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(expGradP.isSameShape(gradP));
    ASSERT_TRUE(expGradP.equalsTo(gradP));
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
    ASSERT_TRUE(expGradP.isSameShape(-*gradL));
    ASSERT_TRUE(expGradP.equalsTo(-*gradL));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test13) {
    // Weights of shape {2,3,1} with the first three entries zeroed; integer argument 3.
    // Checks all three gradients; labels gradient should equal -(predictions gradient).
    NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,1}, nd4j::DataType::DOUBLE);

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    for (int i = 0; i < 3; ++i)
        weights.t<double>(i) = 0.;

    NDArray expGradP('c', {2,3,4}, {0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0.,
                                    -0.04167, -0.04167, -0.04167, -0.04167,-0.04167, -0.04167, -0.04167, -0.04167,-0.04167, -0.04167, -0.04167, -0.04167});
    NDArray expGradW('c', {2,3,1}, {0.8 ,2.08,3.36,4.64,5.92,7.2 });

    nd4j::ops::absolute_difference_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(expGradP.isSameShape(gradP));
    ASSERT_TRUE(expGradP.equalsTo(gradP));
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
    ASSERT_TRUE(expGradP.isSameShape(-*gradL));
    ASSERT_TRUE(expGradP.equalsTo(-*gradL));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, BFloat16_Test_1) {
    // Adds two bfloat16 arrays holding 1, 2, 3, ...; the sum should be 2, 4, 6, ...
    NDArray x = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    NDArray y = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    NDArray expected = NDArrayFactory::create<bfloat16>('c', {2,3,4});

    x.linspace(1);
    y.linspace(1);
    expected.linspace(2,2);

    nd4j::ops::add addOp;
    auto outputs = addOp.execute({&x, &y}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto sum = outputs->at(0);
    ASSERT_TRUE(sum->equalsTo(expected));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, BFloat16_Test_2) {
    // Mixed-type add: x is float16 and y is bfloat16, both holding 1, 2, 3, ...
    // The result is compared against a float16 array holding 2, 4, 6, ...
    NDArray x = NDArrayFactory::create<float16>('c', {2,3,4});
    NDArray y = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    NDArray expected = NDArrayFactory::create<float16>('c', {2,3,4});

    x.linspace(1);
    y.linspace(1);
    expected.linspace(2,2);

    nd4j::ops::add addOp;
    auto outputs = addOp.execute({&x, &y}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto sum = outputs->at(0);
    ASSERT_TRUE(sum->equalsTo(expected));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, BFloat16_Test_3) {
    // Same as BFloat16_Test_1 but the arrays are built from a DataType enum value.
    NDArray x('c', {2,3,4}, nd4j::DataType::BFLOAT16);
    NDArray y('c', {2,3,4}, nd4j::DataType::BFLOAT16);
    NDArray expected('c', {2,3,4}, nd4j::DataType::BFLOAT16);

    x.linspace(1);
    y.linspace(1);
    expected.linspace(2,2);

    nd4j::ops::add addOp;
    auto outputs = addOp.execute({&x, &y}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto sum = outputs->at(0);
    ASSERT_TRUE(sum->equalsTo(expected));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test1) {
    // Full-shape weights; float argument 0 (presumably label smoothing — confirm),
    // integer argument 0. All three outputs are compared directly to expectations.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);

    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradP('c', {2,3,4}, {-0.25999, -0.755 , -1.25 , -1.745 , -2.24001, -2.73502, -3.23004, -3.72508, -4.22014, -4.71523, -5.21034, -5.70548,
                                    -6.20066, -6.69587, -7.19113, -7.68643, -8.18177, -8.67717, -9.17262, -9.66813,-10.1637 ,-10.65932,-11.15501,-11.65077});
    NDArray expGradW('c', {2,3,4}, {0.73395, 0.75335, 0.69315, 0.55335, 0.33395, 0.03495, -0.34366, -0.80186, -1.33967, -1.95708, -2.65411, -3.43074,
                                    -4.28698, -5.22285, -6.23833, -7.33343, -8.50815, -9.76251,-11.0965 ,-12.51013,-14.00341,-15.57633,-17.2289 ,-18.96113});
    NDArray expGradL('c', {2,3,4}, {0.04, 0.02,-0. ,-0.02,-0.04,-0.06,-0.08,-0.1 ,-0.12,-0.14,-0.16,-0.18,
                                    -0.2 ,-0.22,-0.24,-0.26,-0.28,-0.3 ,-0.32,-0.34,-0.36,-0.38,-0.4 ,-0.42});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &weights, &labels}, {0.}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(expGradP.isSameShape(gradP));
    ASSERT_TRUE(expGradP.equalsTo(gradP));
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
    ASSERT_TRUE(expGradL.isSameShape(gradL));
    ASSERT_TRUE(expGradL.equalsTo(gradL));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test2) {
    // Weights of shape {2,1,4}; float argument 0.3, integer argument 0.
    // All three outputs are compared directly to expectations.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,1,4}, nd4j::DataType::DOUBLE);

    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradP('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048,
                                    -4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577});
    NDArray expGradW('c', {2,1,4}, {0.43622, -0.19079, -0.98462, -1.94525,-18.09855,-20.72768,-23.52373,-26.48669});
    NDArray expGradL('c', {2,3,4}, {0.028, 0.014, -0. , -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126,
                                    -0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &weights, &labels}, {0.3}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(expGradP.isSameShape(gradP));
    ASSERT_TRUE(expGradP.equalsTo(gradP));
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
    ASSERT_TRUE(expGradL.isSameShape(gradL));
    ASSERT_TRUE(expGradL.equalsTo(gradL));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) {
    // Scalar weights; float argument 0.3, integer argument 1.
    // All three outputs are compared directly to expectations.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);

    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradP('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048,
                                    -4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577});
    NDArray expGradW('c', {}, {-91.52109});
    NDArray expGradL('c', {2,3,4}, {0.028, 0.014, -0., -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126,
                                    -0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &weights, &labels}, {0.3}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(expGradP.isSameShape(gradP));
    ASSERT_TRUE(expGradP.equalsTo(gradP));
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
    ASSERT_TRUE(expGradL.isSameShape(gradL));
    ASSERT_TRUE(expGradL.equalsTo(gradL));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test4) {
    // Weights of shape {1,3,1}; float argument 0.3, integer argument 1.
    // Only the weights gradient is checked.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);

    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradW('c', {1,3,1}, {-12.54779,-28.13393,-50.83936});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &weights, &labels}, {0.3}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradW = outputs->at(1);
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test5) {
    // Full-shape weights; float argument 0.3, integer argument 2.
    // All three outputs are compared directly to expectations.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);

    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradP('c', {2,3,4}, {-0.01542,-0.04417,-0.07292,-0.10167,-0.13042,-0.15917,-0.18792,-0.21667,-0.24543,-0.27419,-0.30294,-0.33171,
                                    -0.36047,-0.38924,-0.41801,-0.44679,-0.47556,-0.50435,-0.53314,-0.56193,-0.59072,-0.61953,-0.64833,-0.67715});
    NDArray expGradW('c', {2,3,4}, {0.37794, 0.37906, 0.37554, 0.36739, 0.35461, 0.33719, 0.31514, 0.28846, 0.25714, 0.22119, 0.18061, 0.13539,
                                    0.08553, 0.03104,-0.02808,-0.09184,-0.16023,-0.23326,-0.31093,-0.39323,-0.48017,-0.57175,-0.66796,-0.76881});
    NDArray expGradL('c', {2,3,4}, {0.00233, 0.00117,-0.,-0.00117,-0.00233,-0.0035 ,-0.00467,-0.00583,-0.007 ,-0.00817,-0.00933,-0.0105,
                                    -0.01167,-0.01283,-0.014 ,-0.01517,-0.01633,-0.0175 ,-0.01867,-0.01983,-0.021 ,-0.02217,-0.02333,-0.0245});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &weights, &labels}, {0.3}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(expGradP.isSameShape(gradP));
    ASSERT_TRUE(expGradP.equalsTo(gradP));
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
    ASSERT_TRUE(expGradL.isSameShape(gradL));
    ASSERT_TRUE(expGradL.equalsTo(gradL));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test6) {
    // Weights of shape {1,3,1}; float argument 0.3, integer argument 2.
    // Only the weights gradient is checked.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);

    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradW('c', {1,3,1}, {1.4966 , 0.19776,-1.69436});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &weights, &labels}, {0.3}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradW = outputs->at(1);
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) {
    // Scalar weights; float argument 0.3, integer argument 2.
    // The weights gradient is expected to be 0.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights(nd4j::DataType::DOUBLE);

    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradW('c', {}, {0.});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &weights, &labels}, {0.3}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradW = outputs->at(1);
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test8) {
    // Full-shape weights with the first four entries zeroed; float argument 0.3,
    // integer argument 2. All three outputs are compared directly to expectations.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);

    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    for (int i = 0; i < 4; ++i)
        weights.p(i, 0.);

    NDArray expGradP('c', {2,3,4}, { 0. , 0. , 0. , 0. ,-0.1565 ,-0.191 ,-0.2255 ,-0.26001,-0.29451,-0.32902,-0.36353,-0.39805,
                                    -0.43257,-0.46709,-0.50161,-0.53614,-0.57068,-0.60522,-0.63976,-0.67431,-0.70887,-0.74343,-0.778 ,-0.81258});
    NDArray expGradW('c', {2,3,4}, {0.54353, 0.54487, 0.54065, 0.53087, 0.51553, 0.49463, 0.46817, 0.43615, 0.39857, 0.35543, 0.30672, 0.25246,
                                    0.19264, 0.12725, 0.0563 ,-0.02021,-0.10228,-0.18992,-0.28312,-0.38188,-0.48621,-0.5961 ,-0.71156,-0.83258});
    NDArray expGradL('c', {2,3,4}, {-0. ,-0. , 0. , 0. ,-0.0028,-0.0042,-0.0056,-0.007 ,-0.0084,-0.0098,-0.0112,-0.0126,
                                    -0.014 ,-0.0154,-0.0168,-0.0182,-0.0196,-0.021 ,-0.0224,-0.0238,-0.0252,-0.0266,-0.028 ,-0.0294});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &weights, &labels}, {0.3}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(expGradP.isSameShape(gradP));
    ASSERT_TRUE(expGradP.equalsTo(gradP));
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
    ASSERT_TRUE(expGradL.isSameShape(gradL));
    ASSERT_TRUE(expGradL.equalsTo(gradL));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test9) {
    // Full-shape weights; float argument 0.3, integer argument 3.
    // All three outputs are compared directly to expectations.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);

    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradP('c', {2,3,4}, {-0.00771, -0.02208, -0.03646, -0.05083,-0.06521, -0.07958, -0.09396, -0.10834,-0.12271, -0.13709, -0.15147, -0.16585,
                                    -0.18024, -0.19462, -0.20901, -0.22339,-0.23778, -0.25217, -0.26657, -0.28096,-0.29536, -0.30976, -0.32417, -0.33857});
    NDArray expGradW('c', {2,3,4}, {0.03008, 0.03064, 0.02888, 0.02481, 0.01841, 0.00971, -0.00132, -0.01466,-0.03032, -0.0483 , -0.06859, -0.0912 ,
                                    -0.11612, -0.14337, -0.17293, -0.20481,-0.23901, -0.27552, -0.31435, -0.35551,-0.39898, -0.44476, -0.49287, -0.5433 });
    NDArray expGradL('c', {2,3,4}, {0.00117, 0.00058, -0. , -0.00058,-0.00117, -0.00175, -0.00233, -0.00292,-0.0035 , -0.00408, -0.00467, -0.00525,
                                    -0.00583, -0.00642, -0.007 , -0.00758,-0.00817, -0.00875, -0.00933, -0.00992,-0.0105 , -0.01108, -0.01167, -0.01225});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGradOp;
    auto outputs = lossGradOp.execute({&logits, &weights, &labels}, {0.3}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(expGradP.isSameShape(gradP));
    ASSERT_TRUE(expGradP.equalsTo(gradP));
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
    ASSERT_TRUE(expGradL.isSameShape(gradL));
    ASSERT_TRUE(expGradL.equalsTo(gradL));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test10) {
    // A single {1,1} weight (presumably broadcast over the {2,3,4} inputs); only output 1 (dL/dweights) is checked.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);

    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradWeights('c', {1,1}, {-3.81338});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradWeights = outputs->at(1);
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test11) {
    // Weights of shape {1,3,1} against {2,3,4} inputs; only output 1 (dL/dweights) is checked.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);

    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    NDArray expGradWeights('c', {1,3,1}, {-0.52282, -1.17225, -2.11831});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradWeights = outputs->at(1);
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test12) {
    // Full-shape weights of 0.5 with the first four entries zeroed out.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);

    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    for (int i = 0; i < 4; ++i)
        weights.t<double>(i) = 0.;

    NDArray expGradLogits('c', {2,3,4}, {0., 0., 0., 0., -0.07825, -0.0955, -0.11275, -0.13, -0.14726, -0.16451, -0.18177, -0.19902,
                                         -0.21628, -0.23354, -0.25081, -0.26807, -0.28534, -0.30261, -0.31988, -0.33716, -0.35443, -0.37172, -0.389, -0.40629});
    NDArray expGradWeights('c', {2,3,4}, {0.0361, 0.03677, 0.03466, 0.02977, 0.0221, 0.01165, -0.00158, -0.01759, -0.03638, -0.05795, -0.08231, -0.10944,
                                          -0.13935, -0.17204, -0.20752, -0.24577, -0.28681, -0.33063, -0.37723, -0.42661, -0.47877, -0.53372, -0.59144, -0.65196});
    NDArray expGradLabels('c', {2,3,4}, {-0., -0., 0., 0., -0.0014, -0.0021, -0.0028, -0.0035, -0.0042, -0.0049, -0.0056, -0.0063,
                                         -0.007, -0.0077, -0.0084, -0.0091, -0.0098, -0.0105, -0.0112, -0.0119, -0.0126, -0.0133, -0.014, -0.0147});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits  = outputs->at(0);
    auto *gradWeights = outputs->at(1);
    auto *gradLabels  = outputs->at(2);

    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));
    ASSERT_TRUE(expGradLabels.isSameShape(gradLabels));
    ASSERT_TRUE(expGradLabels.equalsTo(gradLabels));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test13) {
    // Weights of shape {2,3,1}, all 0.5 except the first three entries, which are zeroed.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2,3,1}, nd4j::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);

    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    for (int i = 0; i < 3; ++i)
        weights.t<double>(i) = 0.;

    NDArray expGradLogits('c', {2,3,4}, {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                         -0.36047, -0.38924, -0.41801, -0.44679, -0.47556, -0.50435, -0.53314, -0.56193, -0.59072, -0.61953, -0.64833, -0.67715});
    NDArray expGradWeights('c', {2,3,1}, {0.22882, 0.02428, -0.4768, -1.27447, -2.36878, -3.75981});
    NDArray expGradLabels('c', {2,3,4}, {-0., -0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                         -0.01167, -0.01283, -0.014, -0.01517, -0.01633, -0.0175, -0.01867, -0.01983, -0.021, -0.02217, -0.02333, -0.0245});

    nd4j::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &weights, &labels}, {0.3}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits  = outputs->at(0);
    auto *gradWeights = outputs->at(1);
    auto *gradLabels  = outputs->at(2);

    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));
    ASSERT_TRUE(expGradLabels.isSameShape(gradLabels));
    ASSERT_TRUE(expGradLabels.equalsTo(gradLabels));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, BFloat16_Test_4) {
    // Mixed-precision add: float32 x plus bfloat16 y, compared against a float32 reference.
    NDArray x        = NDArrayFactory::create<float>('c', {2,3,4});
    NDArray y        = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    NDArray expected = NDArrayFactory::create<float>('c', {2,3,4});

    x.linspace(1);
    y.linspace(1);
    expected.linspace(2, 2);   // element-wise x + y for the sequences above

    nd4j::ops::add addOp;
    auto outputs = addOp.execute({&x, &y}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto result = outputs->at(0);
    ASSERT_TRUE(result->equalsTo(expected));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, BFloat16_Test_5) {
    // Mixed-precision subtract: float32 x minus bfloat16 y, compared against a float32 reference.
    NDArray x        = NDArrayFactory::create<float>('c', {2,3,4});
    NDArray y        = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    NDArray expected = NDArrayFactory::create<float>('c', {2,3,4});

    x.linspace(2, 2);
    y.linspace(1);
    expected.linspace(1);      // element-wise x - y for the sequences above

    nd4j::ops::subtract subOp;
    auto outputs = subOp.execute({&x, &y}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto result = outputs->at(0);
    ASSERT_TRUE(result->equalsTo(expected));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, BFloat16_Test_6) {
    // Mixed-precision subtract: bfloat16 x minus double y, compared against a bfloat16 reference.
    NDArray x        = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    NDArray y        = NDArrayFactory::create<double>('c', {2,3,4});
    NDArray expected = NDArrayFactory::create<bfloat16>('c', {2,3,4});

    x.linspace(2, 2);
    y.linspace(1);
    expected.linspace(1);      // element-wise x - y for the sequences above

    nd4j::ops::subtract subOp;
    auto outputs = subOp.execute({&x, &y}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto result = outputs->at(0);
    ASSERT_TRUE(result->equalsTo(expected));

    delete outputs;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test1) {
    // One-hot INT32 labels for 2 rows of 4 classes, one weight per row.
    NDArray labels('c', {2,4}, {0,0,1,0, 0,1,0,0}, nd4j::DataType::INT32);
    NDArray logits('c', {2,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2}, nd4j::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,4}, {0.1176, 0.1224, -0.3726, 0.1326, 0.1176, -0.3776, 0.1274, 0.1326});
    NDArray dLdwExp('c', {2}, {1.36729, 1.40729});

    logits.linspace(-0.08, 0.04);
    weights.assign(0.5);

    nd4j::ops::softmax_cross_entropy_loss_grad op;
    auto results = op.execute({&logits, &weights, &labels}, {0.}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    // Only the logits (0) and weights (1) gradients are verified; output 2 is not checked here,
    // so it is not fetched (the previous unused `dLdl` local has been removed).
    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete results;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test2) {
    // Rank-1 case: constant logits, one-hot INT32 labels, a single {1} weight.
    NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
    NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1}, nd4j::DataType::DOUBLE);

    NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125});
    NDArray dLdwExp('c', {1}, {1.38629});

    logits = 2.;
    weights.assign(0.5);

    nd4j::ops::softmax_cross_entropy_loss_grad op;
    auto results = op.execute({&logits, &weights, &labels}, {0.}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    // Output 2 is not verified by this test, so it is not fetched (unused local removed).
    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete results;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test3) {
    // Same as test2 but with a rank-0 (scalar) weight.
    NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
    NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE);

    NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125});
    NDArray dLdwExp('c', {}, {1.38629});

    logits = 2.;
    weights.assign(0.5);

    nd4j::ops::softmax_cross_entropy_loss_grad op;
    auto results = op.execute({&logits, &weights, &labels}, {0.}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    // Output 2 is not verified by this test, so it is not fetched (unused local removed).
    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete results;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test4) {
    // Scalar weight with iArg 2; the expected weights gradient is zero.
    NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
    NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE);

    NDArray dLdpExp('c', {4}, {0.23521, 0.2448, -0.7452, 0.26519});
    NDArray dLdwExp('c', {}, {0.});

    logits.linspace(-0.08, 0.04);
    weights = 0.5;

    nd4j::ops::softmax_cross_entropy_loss_grad op;
    auto results = op.execute({&logits, &weights, &labels}, {0.}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    // Output 2 is not verified by this test, so it is not fetched (unused local removed).
    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete results;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test5) {
    // Rank-1 logits with a {1} weight and iArg 3.
    NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
    NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1}, nd4j::DataType::DOUBLE);

    NDArray dLdpExp('c', {4}, {0.1176, 0.1224, -0.3726, 0.1326});
    NDArray dLdwExp('c', {1}, {1.36729});

    logits.linspace(-0.08, 0.04);
    weights = 0.5;

    nd4j::ops::softmax_cross_entropy_loss_grad op;
    auto results = op.execute({&logits, &weights, &labels}, {0.}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    // Output 2 is not verified by this test, so it is not fetched (unused local removed).
    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete results;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test6) {
    // Two rows with per-row weights, tArg 0.3 and iArg 2.
    NDArray labels('c', {2,4}, {0,0,1,0, 0,1,0,0}, nd4j::DataType::INT32);
    NDArray logits('c', {2,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {2}, nd4j::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,4}, {0.0801, 0.0849, -0.2601, 0.0951, 0.0801, -0.2651, 0.0899, 0.0951});
    NDArray dLdwExp('c', {2}, {-0.014000, 0.014000});

    logits.linspace(-0.08, 0.04);
    weights.assign(0.5);

    nd4j::ops::softmax_cross_entropy_loss_grad op;
    auto results = op.execute({&logits, &weights, &labels}, {0.3}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    // Output 2 is not verified by this test, so it is not fetched (unused local removed).
    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete results;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test7) {
    // Rank-3 logits with {1,3} weights (middle weight is zero) and iArg 3.
    NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1, 1,0,0,0, 0,1,0,0}, nd4j::DataType::INT32);
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,3}, {0.5, 0., 1.5});

    NDArray dLdpExp('c', {2,3,4}, {-0.0956, 0.0306, 0.03185, 0.03315, 0., -0., 0., 0., 0.0882, 0.0918, -0.27945, 0.09945,
                                   0.0294, 0.0306, 0.03185, -0.09185, -0., 0., 0., 0., 0.0882, -0.2832, 0.09555, 0.09945});
    NDArray dLdwExp('c', {1,3}, {0.69365, 0.71365, 0.69365});

    logits.linspace(-0.08, 0.04);

    nd4j::ops::softmax_cross_entropy_loss_grad op;
    auto results = op.execute({&logits, &weights, &labels}, {0.}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    // Output 2 is not verified by this test, so it is not fetched (unused local removed).
    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete results;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test8) {
    // Rank-4 logits {2,3,4,5} with {1,1,4} weights and iArg 2.
    NDArray labels('c', {2,3,4,5}, {1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,
                                    0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,
                                    0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0}, nd4j::DataType::INT32);

    NDArray logits('c', {2,3,4,5}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {1,1,4}, nd4j::DataType::DOUBLE);

    // NOTE(review): dLdpExp is kept for reference only. The dL/dlogits comparison was already
    // disabled in this test (see the commented-out asserts below); TODO confirm these values
    // and re-enable the check.
    NDArray dLdpExp('c', {2,3,4,5}, {-0.03399, 0.00799, 0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335,
                                      0.00866, 0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399,
                                      0.00799, 0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866,
                                      0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799,
                                      0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901,
                                      0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799, 0.00832,
                                      0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901, 0.00768,
                                      0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799, 0.00832, 0.00866,
                                      0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901});

    NDArray dLdwExp('c', {1,1,4}, {0.005, 0.00167, -0.00167, -0.005});

    logits.linspace(-0.08, 0.04);
    weights.assign(0.5);

    nd4j::ops::softmax_cross_entropy_loss_grad op;
    auto results = op.execute({&logits, &weights, &labels}, {0.}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    // Only the weights gradient (output 1) is verified; outputs 0 and 2 are not fetched,
    // which removes the previously unused `dLdp`/`dLdl` locals.
    auto *dLdw = results->at(1);

    // Disabled dL/dlogits check (kept from the original test):
    // auto *dLdp = results->at(0);
    // ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    // ASSERT_TRUE(dLdpExp.equalsTo(dLdp));

    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete results;
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, SafeDivideMixed_Test1) {
    // SafeDivide with mixed operand types: floating-point row sums divided in place by an
    // INT64 array of ones. Previously this test made no assertion, so it could only fail
    // by crashing; it now verifies the result.
    NDArray labels('c', {2, 3}, {1.0, 2.0, 3.0, -1.0, 2.0, 1.0});

    // Row sums with reduced dims kept: {1+2+3, -1+2+1} = {6, 2}.
    auto sumDiff = labels.reduceAlongDims(reduce::Sum, {1}, true);

    NDArray numOfNonZero(sumDiff.getShapeInfo(), nd4j::DataType::INT64, false);
    numOfNonZero.assign(1);

    sumDiff.applyPairwiseTransform(pairwise::SafeDivide, &numOfNonZero, &sumDiff, nullptr);

    // Dividing by one must leave the row sums unchanged. Built with the same constructor as
    // `labels` so both arrays share the same default data type.
    NDArray expected('c', {2, 1}, {6.0, 2.0});
    ASSERT_TRUE(expected.isSameShape(&sumDiff));
    ASSERT_TRUE(expected.equalsTo(&sumDiff));
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test1) {
    // One-hot labels over the last axis, no integer args (op default dimension).
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);

    NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1, 1,0,0,0, 0,1,0,0});

    NDArray expGradLogits('c', {2,3,4}, {-0.76479, 0.2448, 0.2548, 0.26519, 0.23521, -0.7552, 0.2548, 0.26519, 0.23521, 0.2448, -0.7452, 0.26519,
                                         0.23521, 0.2448, 0.2548, -0.73481, -0.76479, 0.2448, 0.2548, 0.26519, 0.23521, -0.7552, 0.2548, 0.26519});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &labels}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));

    delete outputs;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test2) {
    // Multi-hot labels with the class dimension passed explicitly as iArg 1.
    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);

    NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,1, 0,0,1,0, 0,0,0,1, 1,0,1,0, 0,1,0,0});

    NDArray expGradLogits('c', {2,3,4}, {-0.71836, 0.28164, 0.28164, 0.28164, 0.33051, -0.66949, 0.33051, -0.66949, 0.38785, 0.38785, -0.61215, 0.38785,
                                         0.28164, 0.28164, 0.28164, -0.71836, -0.66949, 0.33051, -0.66949, 0.33051, 0.38785, -0.61215, 0.38785, 0.38785});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &labels}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));

    delete outputs;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test3) {
    // Rank-2 input with the class dimension set to 0 (iArg 0).
    NDArray logits('c', {2,3}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);

    NDArray labels('c', {2,3}, {1,0,0, 0,1,1});

    NDArray expGradLogits('c', {2,3}, {-0.52996, 0.47004, 0.47004, 0.52996, -0.47004, -0.47004});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));

    delete outputs;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test4) {
    // Class dimension 1 has size 1, so the expected gradient is all zeros.
    NDArray logits('c', {2,1}, {-0.04, 0.04});
    NDArray labels('c', {2,1}, {1,1});

    NDArray expGradLogits('c', {2,1}, {0., 0.});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &labels}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));

    delete outputs;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test5) {
    // Class dimension 0 of a {2,1} column: softmax across the two rows.
    NDArray logits('c', {2,1}, {-0.04, 0.04});
    NDArray labels('c', {2,1}, {1,0});

    NDArray expGradLogits('c', {2,1}, {-0.51999, 0.51999});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));

    delete outputs;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test6) {
    // Class dimension 0 of a {1,2} row has size 1, so the expected gradient is all zeros.
    NDArray logits('c', {1,2}, {-0.04, 0.04});
    NDArray labels('c', {1,2}, {1,1});

    NDArray expGradLogits('c', {1,2}, {0, 0});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));

    delete outputs;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test7) {
    // Rank-1 input: softmax over both elements, one-hot label on the second.
    NDArray logits('c', {2}, {-0.04, 0.04});
    NDArray labels('c', {2}, {0,1});

    NDArray expGradLogits('c', {2}, {0.48001, -0.48001});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));

    delete outputs;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test8) {
    // Single-element input: the expected gradient is zero.
    NDArray logits('c', {1}, {0.04});
    NDArray labels('c', {1}, {1});

    NDArray expGradLogits('c', {1}, {0});

    nd4j::ops::softmax_cross_entropy_loss_with_logits_grad lossGrad;
    auto outputs = lossGrad.execute({&logits, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));

    delete outputs;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, Multiply_BP_Test1) {
    // multiply_bp for z = x * y, with x of shape {3,4,5} and a single-element y of shape {1,1,1}.
    // All inputs and the incoming gradient dLdz are ones.
    NDArray x('c', {3,4,5}, nd4j::DataType::DOUBLE);
    NDArray y('c', {1,1,1}, nd4j::DataType::DOUBLE);
    NDArray dLdz('c', {3,4,5}, nd4j::DataType::DOUBLE);

    NDArray dLdxExp('c', {3,4,5}, nd4j::DataType::DOUBLE);
    NDArray dLdyExp('c', {1,1,1}, nd4j::DataType::DOUBLE);

    x.assign(1.0);
    y.assign(1.0);
    dLdz.assign(1.0);

    // dL/dx = dLdz * y -> 1 everywhere.
    dLdxExp.assign(1.0);
    // dL/dy = sum(dLdz * x) reduced to y's shape -> 3*4*5 = 60.
    // This output was not verified before.
    dLdyExp.assign(60.0);

    nd4j::ops::multiply_bp op;
    auto results = op.execute({&x, &y, &dLdz}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdx = results->at(0);
    auto *dLdy = results->at(1);

    ASSERT_TRUE(dLdxExp.isSameShape(dLdx));
    ASSERT_TRUE(dLdxExp.equalsTo(dLdx));
    ASSERT_TRUE(dLdyExp.isSameShape(dLdy));
    ASSERT_TRUE(dLdyExp.equalsTo(dLdy));

    delete results;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test1) {
    // Sparse variant: integer class indices (one per row); note labels are passed BEFORE logits.
    NDArray labels('c', {2}, {2,1}, nd4j::DataType::INT64);

    NDArray logits('c', {2,3}, nd4j::DataType::DOUBLE);
    logits.linspace(0.1, 0.1);

    NDArray expGradLogits('c', {2,3}, {0.30061, 0.33222, -0.63283, 0.30061, -0.66778, 0.36717});

    nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad lossGrad;
    auto outputs = lossGrad.execute({&labels, &logits}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));

    delete outputs;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test2) {
    // Sparse variant with class indices {0, 1}; labels come first in the input list.
    NDArray labels('c', {2}, {0,1}, nd4j::DataType::INT64);

    NDArray logits('c', {2,3}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.1, 0.1);

    NDArray expGradLogits('c', {2,3}, {-0.69939, 0.33222, 0.36717, 0.30061, -0.66778, 0.36717});

    nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad lossGrad;
    auto outputs = lossGrad.execute({&labels, &logits}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));

    delete outputs;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) {
    // Scalar class index against a rank-1 logits vector.
    NDArray labels('c', {}, {1}, nd4j::DataType::INT64);
    NDArray logits('c', {2}, {-0.2, 0.3});

    NDArray expGradLogits('c', {2}, {0.37754, -0.37754});

    nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad lossGrad;
    auto outputs = lossGrad.execute({&labels, &logits}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));

    delete outputs;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test4) {
    // {2,3} class indices against {2,3,4} logits (4 classes on the last axis).
    NDArray labels('c', {2,3}, {0,1,1, 3,3,2}, nd4j::DataType::INT64);

    NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
    logits.linspace(-0.5, 0.1);

    NDArray expGradLogits('c', {2,3,4}, {-0.78616, 0.23633, 0.26118, 0.28865, 0.21384, -0.76367, 0.26118, 0.28865, 0.21384, -0.76367, 0.26118, 0.28865,
                                         0.21384, 0.23633, 0.26118, -0.71135, 0.21384, 0.23633, 0.26118, -0.71135, 0.21384, 0.23633, -0.73882, 0.28865});

    nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad lossGrad;
    auto outputs = lossGrad.execute({&labels, &logits}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));

    delete outputs;
}
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test5) {
    // {1,1} class index against {1,1,2} logits.
    NDArray labels('c', {1,1}, {0}, nd4j::DataType::INT64);
    NDArray logits('c', {1,1,2}, {-0.3, 0.2});

    NDArray expGradLogits('c', {1,1,2}, {-0.62246, 0.62246});

    nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad lossGrad;
    auto outputs = lossGrad.execute({&labels, &logits}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradLogits = outputs->at(0);
    ASSERT_TRUE(expGradLogits.isSameShape(gradLogits));
    ASSERT_TRUE(expGradLogits.equalsTo(gradLogits));

    delete outputs;
}
|
|
|
|
|
|
|