2019-06-06 14:21:15 +02:00
|
|
|
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
|
|
|
|
|
|
|
|
|
|
|
|
//
// Created by raver on 8/4/2018.
//
|
|
|
|
|
|
|
|
#include "testlayers.h"
|
|
|
|
#include <ops/declarable/CustomOperations.h>
|
2020-03-02 10:49:41 +01:00
|
|
|
#include <array/NDArray.h>
|
2019-06-06 14:21:15 +02:00
|
|
|
#include <ops/ops.h>
|
2020-03-02 10:49:41 +01:00
|
|
|
#include <helpers/GradCheck.h>
|
2020-02-28 09:37:26 +01:00
|
|
|
#include <helpers/MmulHelper.h>
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
using namespace sd;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
|
|
|
|
/**
 * GoogleTest fixture shared by every TEST_F in the DeclarableOpsTests11 suite.
 *
 * The constructor only emits a newline on stdout and flushes it; the fixture
 * holds no state of its own.
 */
class DeclarableOpsTests11 : public testing::Test {
public:

    DeclarableOpsTests11() {
        // Print a line break and flush so it reaches the console immediately.
        printf("\n");
        fflush(stdout);
    }
};
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Smoke test for the listdiff op: runs it on two small integer vectors and
 * verifies only that the op reports an OK status (output values are not
 * inspected here).
 */
TEST_F(DeclarableOpsTests11, test_listdiff_1) {
    auto x = NDArrayFactory::create<int>('c', {4}, {0, 1, 2, 3});
    auto y = NDArrayFactory::create<int>('c',{2}, {3, 1});

    sd::ops::listdiff op;
    auto result = op.evaluate({&x, &y}, {}, {});
    ASSERT_EQ(Status::OK(), result->status());

    // NOTE(review): `result` is a raw owning pointer; if the ASSERT above fails
    // the test returns early and this delete is skipped.
    delete result;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
/**
 * log_loss_grad with full-shape {2,3,4} weights and integer argument 0.
 *
 * Compares all three outputs of the op (indices 0, 1, 2) against precomputed
 * reference arrays dLdpExp, dLdwExp and dLdlExp.
 */
TEST_F(DeclarableOpsTests11, log_loss_grad_test1) {

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE);

    // Expected (reference) values for outputs 0, 1 and 2 respectively.
    NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692,
                                   -24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911});
    NDArray dLdwExp('c', {2,3,4}, {3.21887, 4.96807, 6.10512, 6.80726, 7.15461, 7.19051, 6.93973, 6.41584, 5.62456, 4.56548, 3.2326 , 1.61444,
                                   -0.30659, -2.55529, -5.16569, -8.18417,-11.67468,-15.72734,-20.47379,-26.11644,-32.9902 ,-41.71318,-53.64824,-73.05434});
    NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002,
                                   -0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903});

    // Fill the inputs with deterministic data.
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    sd::ops::log_loss_grad op;
    // NOTE(review): 1e-7 is presumably the epsilon and 0 the reduction mode - confirm against the op declaration.
    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
    ASSERT_TRUE(dLdlExp.equalsTo(dLdl));

    // NOTE(review): a failing ASSERT_* above returns early and skips this delete.
    delete results;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
/**
 * log_loss_grad with weights of shape {2,1,4} against {2,3,4} inputs and
 * integer argument 0.
 *
 * Only output 1 is checked, against the {2,1,4} reference array dLdwExp.
 */
TEST_F(DeclarableOpsTests11, log_loss_grad_test2) {

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {2,1,4}, sd::DataType::DOUBLE);

    // Expected (reference) values for output 1.
    NDArray dLdwExp('c', {2,1,4}, {15.99805, 16.72406, 16.27746, 14.83754,-44.97147,-59.99582,-79.28771,-107.35497});

    // Fill the inputs with deterministic data.
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    sd::ops::log_loss_grad op;
    // NOTE(review): 1e-7 is presumably the epsilon and 0 the reduction mode - confirm against the op declaration.
    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdw = results->at(1);

    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    // NOTE(review): a failing ASSERT_* above returns early and skips this delete.
    delete results;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, log_loss_grad_test3) {
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray weights(sd::DataType::DOUBLE);
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692,
|
|
|
|
-24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911});
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray dLdwExp('c', {}, std::vector<double>{-227.77286});
|
2019-06-06 14:21:15 +02:00
|
|
|
NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002,
|
|
|
|
-0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903});
|
|
|
|
|
|
|
|
predictions.linspace(0.04, 0.04);
|
|
|
|
labels.linspace(1);
|
2019-09-11 19:12:09 +02:00
|
|
|
weights.assign(0.5);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::log_loss_grad op;
|
2020-01-30 08:07:24 +01:00
|
|
|
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
|
2019-09-11 19:12:09 +02:00
|
|
|
auto *dLdp = results->at(0);
|
2019-06-06 14:21:15 +02:00
|
|
|
auto *dLdw = results->at(1);
|
|
|
|
auto *dLdl = results->at(2);
|
|
|
|
|
|
|
|
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
|
|
|
|
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
|
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
|
|
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
|
|
|
|
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
|
|
|
|
|
|
|
|
delete results;
|
|
|
|
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
/**
 * log_loss_grad with weights of shape {1,3,1} against {2,3,4} inputs and
 * integer argument 1.
 *
 * Only output 1 is checked, against the {1,3,1} reference array dLdwExp.
 */
TEST_F(DeclarableOpsTests11, log_loss_grad_test4) {

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE);

    // Expected (reference) values for output 1.
    NDArray dLdwExp('c', {1,3,1}, {4.8876 , -46.29156, -186.36887});

    // Fill the inputs with deterministic data.
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    sd::ops::log_loss_grad op;
    // NOTE(review): 1e-7 is presumably the epsilon and 1 the reduction mode - confirm against the op declaration.
    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdw = results->at(1);
    // Debug output, left disabled:
    // dLdw->printIndexedBuffer();
    // dLdw->printShapeInfo();

    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    // NOTE(review): a failing ASSERT_* above returns early and skips this delete.
    delete results;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
/**
 * log_loss_grad with full-shape {2,3,4} weights and integer argument 2.
 *
 * Compares all three outputs of the op (indices 0, 1, 2) against precomputed
 * reference arrays dLdpExp, dLdwExp and dLdlExp.
 */
TEST_F(DeclarableOpsTests11, log_loss_grad_test5) {

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE);

    // Expected (reference) values for outputs 0, 1 and 2 respectively.
    NDArray dLdpExp('c', {2,3,4}, {-1.04166,-1.08696, -1.13636, -1.19048,-1.25 ,-1.31579, -1.38889, -1.47059,-1.5625 ,-1.66667, -1.78571, -1.92308,
                                   -2.08333,-2.27273, -2.5 , -2.77778,-3.125 ,-3.57143, -4.16667, -5. ,-6.25 ,-8.33333,-12.49999,-24.99993});
    NDArray dLdwExp('c', {2,3,4}, {1.05912, 1.20488, 1.29964, 1.35815, 1.3871 , 1.39009, 1.36919, 1.32553, 1.25959, 1.17133, 1.06026, 0.92541,
                                   0.76533, 0.57794, 0.3604 , 0.10886,-0.18201,-0.51973,-0.91527,-1.38549,-1.95831,-2.68522,-3.67981,-5.29698});
    NDArray dLdlExp('c', {2,3,4}, {0.13242, 0.10176, 0.08302, 0.06909, 0.05776, 0.04803, 0.03935, 0.03141, 0.02397, 0.01689, 0.01005, 0.00334,
                                   -0.00334,-0.01005,-0.01689,-0.02397,-0.03141,-0.03935,-0.04803,-0.05776,-0.06909,-0.08302,-0.10176,-0.13242});

    // Fill the inputs with deterministic data.
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    sd::ops::log_loss_grad op;
    // NOTE(review): 1e-7 is presumably the epsilon and 2 the reduction mode - confirm against the op declaration.
    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
    ASSERT_TRUE(dLdlExp.equalsTo(dLdl));

    // NOTE(review): a failing ASSERT_* above returns early and skips this delete.
    delete results;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
/**
 * log_loss_grad with weights of shape {1,3,1} against {2,3,4} inputs and
 * integer argument 2.
 *
 * Only output 1 is checked, against the {1,3,1} reference array dLdwExp.
 */
TEST_F(DeclarableOpsTests11, log_loss_grad_test6) {

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE);

    // Expected (reference) values for output 1.
    NDArray dLdwExp('c', {1,3,1}, {6.73432, 2.46939,-9.20372});

    // Fill the inputs with deterministic data.
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    sd::ops::log_loss_grad op;
    // NOTE(review): 1e-7 is presumably the epsilon and 2 the reduction mode - confirm against the op declaration.
    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdw = results->at(1);

    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    // NOTE(review): a failing ASSERT_* above returns early and skips this delete.
    delete results;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, log_loss_grad_test7) {
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray weights(sd::DataType::DOUBLE);
|
2019-09-11 19:12:09 +02:00
|
|
|
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray dLdwExp('c', {}, std::vector<double>{0.});
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
predictions.linspace(0.04, 0.04);
|
|
|
|
labels.linspace(1);
|
2019-09-11 19:12:09 +02:00
|
|
|
weights.assign(0.5);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::log_loss_grad op;
|
2020-01-30 08:07:24 +01:00
|
|
|
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
|
|
|
|
auto *dLdw = results->at(1);
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
|
|
|
|
|
|
delete results;
|
|
|
|
}
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
/**
 * log_loss_grad with full-shape {2,3,4} weights whose first four elements are
 * set to zero, and integer argument 2.
 *
 * Compares all three outputs of the op (indices 0, 1, 2) against precomputed
 * reference arrays dLdpExp, dLdwExp and dLdlExp.
 */
TEST_F(DeclarableOpsTests11, log_loss_grad_test8) {

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE);

    // Expected (reference) values for outputs 0, 1 and 2 respectively.
    NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. ,-1.5 ,-1.57895, -1.66667, -1.76471,-1.875 ,-2. , -2.14286, -2.30769,
                                   -2.5 ,-2.72727, -3. , -3.33333,-3.75 ,-4.28571, -5. , -6. ,-7.49999,-9.99999,-14.99999,-29.99991});
    NDArray dLdwExp('c', {2,3,4}, {1.56625, 1.74117, 1.85487, 1.92509, 1.95982, 1.96341, 1.93833, 1.88594, 1.80682, 1.70091, 1.56762, 1.4058 ,
                                   1.2137 , 0.98883, 0.72779, 0.42594, 0.07689,-0.32837,-0.80302,-1.36728,-2.05466,-2.92696,-4.12046,-6.06107});
    NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0.06931, 0.05763, 0.04722, 0.03769, 0.02877, 0.02027, 0.01206, 0.004,
                                   -0.004 ,-0.01206,-0.02027,-0.02877,-0.03769,-0.04722,-0.05763,-0.06931,-0.08291,-0.09962,-0.12212,-0.1589});

    // Fill the inputs with deterministic data.
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // Zero out the first four weights.
    weights.p(0, 0.);
    weights.p(1, 0.);
    weights.p(2, 0.);
    weights.p(3, 0.);

    sd::ops::log_loss_grad op;
    // NOTE(review): 1e-7 is presumably the epsilon and 2 the reduction mode - confirm against the op declaration.
    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
    ASSERT_TRUE(dLdlExp.equalsTo(dLdl));

    // NOTE(review): a failing ASSERT_* above returns early and skips this delete.
    delete results;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
/**
 * log_loss_grad with full-shape {2,3,4} weights and integer argument 3.
 *
 * Compares all three outputs of the op (indices 0, 1, 2) against precomputed
 * reference arrays dLdpExp, dLdwExp and dLdlExp.
 */
TEST_F(DeclarableOpsTests11, log_loss_grad_test9) {

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE);

    // Expected (reference) values for outputs 0, 1 and 2 respectively.
    NDArray dLdpExp('c', {2,3,4}, {-0.52083,-0.54348,-0.56818, -0.59524,-0.625 ,-0.65789,-0.69444, -0.73529,-0.78125,-0.83333,-0.89286, -0.96154,
                                   -1.04167,-1.13636,-1.25 , -1.38889,-1.5625 ,-1.78571,-2.08333, -2.5 ,-3.125 ,-4.16666,-6.24999,-12.49996});
    NDArray dLdwExp('c', {2,3,4}, {0.13412, 0.207 , 0.25438, 0.28364, 0.29811, 0.2996 , 0.28916, 0.26733, 0.23436, 0.19023, 0.13469, 0.06727,
                                   -0.01277,-0.10647,-0.21524,-0.34101,-0.48645,-0.65531,-0.85307,-1.08819,-1.37459,-1.73805,-2.23534,-3.04393});
    NDArray dLdlExp('c', {2,3,4}, {0.06621, 0.05088, 0.04151, 0.03455, 0.02888, 0.02401, 0.01968, 0.0157 , 0.01199, 0.00845, 0.00502, 0.00167,
                                   -0.00167,-0.00502,-0.00845,-0.01199,-0.0157 ,-0.01968,-0.02401,-0.02888,-0.03455,-0.04151,-0.05088,-0.06621});

    // Fill the inputs with deterministic data.
    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    sd::ops::log_loss_grad op;
    // NOTE(review): 1e-7 is presumably the epsilon and 3 the reduction mode - confirm against the op declaration.
    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
    ASSERT_TRUE(dLdlExp.equalsTo(dLdl));

    // NOTE(review): a failing ASSERT_* above returns early and skips this delete.
    delete results;
}
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, log_loss_grad_test10) {
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray weights('c', {1,1}, sd::DataType::DOUBLE);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray dLdwExp('c', {1,1}, std::vector<double>{-9.49054});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
predictions.linspace(0.04, 0.04);
|
|
|
|
labels.linspace(1);
|
|
|
|
weights.assign(0.5);
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::log_loss_grad op;
|
2020-01-30 08:07:24 +01:00
|
|
|
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
2019-09-11 19:12:09 +02:00
|
|
|
|
|
|
|
auto *dLdw = results->at(1);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
|
|
|
|
|
|
delete results;
|
|
|
|
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
// log_loss_grad with {2,3,4} predictions/labels and {1,3,1} weights:
// only the gradient returned at output index 1 (dLdw) is verified.
TEST_F(DeclarableOpsTests11, log_loss_grad_test11) {

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE);

    NDArray dLdwExp('c', {1,3,1}, {0.20365,-1.92882,-7.76537});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    sd::ops::log_loss_grad op;
    // NOTE(review): the T-argument 1e-7 and the integer argument 3 are presumably
    // epsilon and reduction mode - confirm against the op declaration.
    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});

    // ASSERT_* returns from the enclosing function on failure. Running the checks
    // inside a void lambda keeps that early exit from skipping the delete below.
    [&]() {
        ASSERT_EQ(ND4J_STATUS_OK, results->status());

        auto *dLdw = results->at(1);

        ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
        ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    }();

    delete results;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
// log_loss_grad with full-shape {2,3,4} weights whose first four elements are
// zeroed: all three returned gradients (dLdp, dLdw, dLdl) are verified.
TEST_F(DeclarableOpsTests11, log_loss_grad_test12) {

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,3,4}, { 0. , 0. , 0. , 0. ,-0.75 ,-0.789473,-0.833333, -0.882353,-0.9375 ,-1. ,-1.071428, -1.153846,
                                    -1.25 ,-1.363636,-1.5 , -1.666666,-1.875 ,-2.142857,-2.499999, -2.999999,-3.749997,-4.999997,-7.499993,-14.999956});
    NDArray dLdwExp('c', {2,3,4}, {0.16094, 0.2484 , 0.30526, 0.34036, 0.35773, 0.35953, 0.34699, 0.32079, 0.28123, 0.22827, 0.16163, 0.08072,
                                   -0.01533,-0.12776,-0.25828,-0.40921,-0.58373,-0.78637,-1.02369,-1.30582,-1.64951,-2.08566,-2.68241,-3.65272});
    NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0.03466, 0.02882, 0.02361, 0.01884, 0.01438, 0.01014, 0.00603, 0.002 ,
                                   -0.002 ,-0.00603,-0.01014,-0.01438,-0.01884,-0.02361,-0.02882,-0.03466,-0.04146,-0.04981,-0.06106,-0.07945});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // Zero the first four weights; the expected dLdp/dLdl entries at those positions are 0.
    weights.t<double>(0) = 0.;
    weights.t<double>(1) = 0.;
    weights.t<double>(2) = 0.;
    weights.t<double>(3) = 0.;

    sd::ops::log_loss_grad op;
    // NOTE(review): the T-argument 1e-7 and the integer argument 3 are presumably
    // epsilon and reduction mode - confirm against the op declaration.
    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});

    // ASSERT_* returns from the enclosing function on failure. Running the checks
    // inside a void lambda keeps that early exit from skipping the delete below.
    [&]() {
        ASSERT_EQ(ND4J_STATUS_OK, results->status());

        auto *dLdp = results->at(0);
        auto *dLdw = results->at(1);
        auto *dLdl = results->at(2);

        ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
        ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
        ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
        ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
        ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
        ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
    }();

    delete results;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
// log_loss_grad with {2,3,1} weights whose first three elements are zeroed:
// all three returned gradients (dLdp, dLdw, dLdl) are verified.
TEST_F(DeclarableOpsTests11, log_loss_grad_test13) {

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {2,3,1}, sd::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
                                   -2.08333,-2.27273, -2.5 , -2.77778,-3.125 ,-3.57143, -4.16667, -5. ,-6.25 ,-8.33333,-12.49999,-24.99993});
    NDArray dLdwExp('c', {2,3,1}, {1.75828, 2.30839, 1.25309, -1.35098, -6.16602,-16.78383});
    NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
                                   -0.00334,-0.01005,-0.01689,-0.02397,-0.03141,-0.03935,-0.04803,-0.05776,-0.06909,-0.08302,-0.10176,-0.13242});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // Zero the first three weights; the first twelve expected dLdp/dLdl entries are 0.
    weights.t<double>(0) = 0.;
    weights.t<double>(1) = 0.;
    weights.t<double>(2) = 0.;

    sd::ops::log_loss_grad op;
    // NOTE(review): the T-argument 1e-7 and the integer argument 3 are presumably
    // epsilon and reduction mode - confirm against the op declaration.
    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});

    // ASSERT_* returns from the enclosing function on failure. Running the checks
    // inside a void lambda keeps that early exit from skipping the delete below.
    [&]() {
        ASSERT_EQ(ND4J_STATUS_OK, results->status());

        auto *dLdp = results->at(0);
        auto *dLdw = results->at(1);
        auto *dLdl = results->at(2);

        ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
        ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
        ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
        ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
        ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
        ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
    }();

    delete results;
}
|
|
|
|
|
2019-11-20 20:11:04 +01:00
|
|
|
// resize_bicubic: a {1,7,7,1} float input is resized to 30x30 and compared with
// a precomputed {1,30,30,1} reference table.
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test1) {

    NDArray input = NDArrayFactory::create<float>('c', {1, 7, 7, 1}, {
            1.f, 2.1f, 3.15f, 4.2f, 5.15f, 6.1f, 7.f,
            8.f, 9.1f, 10.f, 11.f, 12.9f, 13.1f, 14.f,
            15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f,
            22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f,
            30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f,
            37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f,
            44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f
    });

    NDArray expected = NDArrayFactory::create<float>('c', {1, 30, 30, 1}, {
            1.f, 1.1976162f, 1.4174359f, 1.6775769f, 1.9961575f, 2.3283265f,
            2.550918f, 2.7360606f, 2.9655411f, 3.2929654f, 3.5441515f, 3.7380352f,
            3.948995f, 4.248106f, 4.5073795f, 4.6843743f, 4.8572845f, 5.104302f,
            5.3869915f, 5.581401f, 5.7539616f, 5.974285f, 6.272836f, 6.5204263f,
            6.718899f, 6.8871036f, 7.039068f, 7.099216f, 7.0784245f, 7.0281887f,
            2.247592f, 2.446947f, 2.6694887f, 2.9312382f, 3.248216f, 3.5745337f,
            3.78931f, 3.9656973f, 4.186417f, 4.5046535f, 4.740569f, 4.9217057f,
            5.133866f, 5.459533f, 5.7744613f, 6.0197873f, 6.254011f, 6.535633f,
            6.8097296f, 6.9607787f, 7.0749416f, 7.241601f, 7.5094895f, 7.7499495f,
            7.954571f, 8.131972f, 8.286526f, 8.346463f, 8.325745f, 8.275683f,
            3.6286845f, 3.830573f, 4.0569587f, 4.3211575f, 4.6364856f, 4.9556503f,
            5.160583f, 5.3258467f, 5.535462f, 5.84216f, 6.058749f, 6.223753f,
            6.437597f, 6.797369f, 7.1836042f, 7.5164022f, 7.8290343f, 8.154773f,
            8.417635f, 8.512958f, 8.5521f, 8.649708f, 8.87788f, 9.108794f,
            9.320926f, 9.509781f, 9.667375f, 9.72694f, 9.706349f, 9.656599f,
            5.276778f, 5.480438f, 5.709702f, 5.9754477f, 6.288551f, 6.6005697f,
            6.796207f, 6.9511423f, 7.1503997f, 7.4461427f, 7.644651f, 7.794562f,
            8.009684f, 8.400473f, 8.851847f, 9.26469f, 9.649218f, 10.015648f,
            10.268647f, 10.313368f, 10.2843275f, 10.319379f, 10.512033f, 10.734956f,
            10.954604f, 11.154507f, 11.315369f, 11.374779f, 11.354242f, 11.304622f,
            7.325373f, 7.5284843f, 7.757575f, 8.022221f, 8.331997f, 8.638187f,
            8.827649f, 8.976217f, 9.168955f, 9.45726f, 9.6442375f, 9.784517f,
            9.999621f, 10.407702f, 10.896234f, 11.355122f, 11.781423f, 12.172186f,
            12.420712f, 12.4374485f, 12.370511f, 12.371386f, 12.545973f, 12.766424f,
            12.992249f, 13.20012f, 13.364252f, 13.424109f, 13.40342f, 13.353425f,
            9.493208f, 9.692467f, 9.9169445f, 10.176801f, 10.482199f, 10.78547f,
            10.974367f, 11.123442f, 11.31637f, 11.603645f, 11.790616f, 11.930889f,
            12.144082f, 12.546447f, 13.024898f, 13.4723f, 13.889232f, 14.276275f,
            14.528972f, 14.555555f, 14.50145f, 14.515459f, 14.700572f, 14.927055f,
            15.156046f, 15.366046f, 15.532901f, 15.594008f, 15.5728855f, 15.521847f,
            10.970133f, 11.163599f, 11.380694f, 11.633735f, 11.935032f, 12.238887f,
            12.43254f, 12.588294f, 12.787534f, 13.079956f, 13.27752f, 13.426631f,
            13.636713f, 14.013844f, 14.441672f, 14.827978f, 15.191209f, 15.549808f,
            15.81343f, 15.881828f, 15.883522f, 15.950411f, 16.16933f, 16.40794f,
            16.636436f, 16.842583f, 17.010887f, 17.07363f, 17.05194f, 16.999537f,
            12.219155f, 12.406129f, 12.614796f, 12.860335f, 13.157928f, 13.464224f,
            13.665207f, 13.830567f, 14.039036f, 14.339629f, 14.552863f, 14.715049f,
            14.921564f, 15.264454f, 15.622843f, 15.924977f, 16.213829f, 16.532364f,
            16.8099f, 16.934835f, 17.012146f, 17.150164f, 17.413412f, 17.666712f,
            17.892765f, 18.09207f, 18.261044f, 18.325531f, 18.303238f, 18.249378f,
            13.7663965f, 13.947391f, 14.148263f, 14.386917f, 14.681246f, 14.990087f,
            15.198166f, 15.372728f, 15.590062f, 15.898583f, 16.126892f, 16.301655f,
            16.50487f, 16.815214f, 17.107498f, 17.329458f, 17.547403f, 17.827654f,
            18.118288f, 18.296928f, 18.4461f, 18.651634f, 18.956806f, 19.22382f,
            19.447308f, 19.639887f, 19.809319f, 19.875397f, 19.852556f, 19.797365f,
            15.9419365f, 16.118704f, 16.314133f, 16.547867f, 16.839561f, 17.14954f,
            17.361883f, 17.542162f, 17.764957f, 18.078188f, 18.315733f, 18.498205f,
            18.699116f, 18.988684f, 19.238989f, 19.410137f, 19.583265f, 19.839512f,
            20.13878f, 20.35177f, 20.546844f, 20.795671f, 21.128067f, 21.404358f,
            21.626736f, 21.8155f, 21.98561f, 22.052843f, 22.029604f, 21.973448f,
            17.53522f, 17.71077f, 17.904636f, 18.13695f, 18.42784f, 18.738056f,
            18.951529f, 19.133352f, 19.357613f, 19.672083f, 19.912102f, 20.096638f,
            20.296894f, 20.580765f, 20.819603f, 20.976887f, 21.137802f, 21.387535f,
            21.689209f, 21.911621f, 22.119276f, 22.37999f, 22.71991f, 22.998823f,
            23.22097f, 23.40876f, 23.57911f, 23.646685f, 23.623325f, 23.566887f,
            18.746353f, 18.922657f, 19.117487f, 19.350685f, 19.64207f, 19.952137f,
            20.164913f, 20.345781f, 20.569134f, 20.88284f, 21.12133f, 21.30459f,
            21.505253f, 21.792645f, 22.038572f, 22.204426f, 22.37289f, 22.626648f,
            22.926834f, 23.143423f, 23.343302f, 23.596668f, 23.931936f, 24.209232f,
            24.431519f, 24.619913f, 24.79011f, 24.857473f, 24.83419f, 24.777927f,
            20.16656f, 20.344206f, 20.540766f, 20.775532f, 21.067804f, 21.377607f,
            21.589132f, 21.768297f, 21.99003f, 22.302366f, 22.538124f, 22.719105f,
            22.920494f, 23.214176f, 23.472767f, 23.653934f, 23.83589f, 24.096842f,
            24.394371f, 24.600555f, 24.786541f, 25.026773f, 25.353731f, 25.62813f,
            25.850672f, 26.04014f, 26.210072f, 26.277063f, 26.253906f, 26.197956f,
            22.363024f, 22.54125f, 22.738552f, 22.973991f, 23.266647f, 23.57634f,
            23.787327f, 23.96576f, 24.186796f, 24.498543f, 24.733124f, 24.913122f,
            25.114826f, 25.411213f, 25.675262f, 25.863028f, 26.050789f, 26.314838f,
            26.611223f, 26.812925f, 26.992926f, 27.227505f, 27.550882f, 27.824034f,
            28.046684f, 28.236614f, 28.406433f, 28.473265f, 28.450163f, 28.394344f,
            24.429443f, 24.60767f, 24.80497f, 25.04041f, 25.333065f, 25.642756f,
            25.853743f, 26.032173f, 26.25321f, 26.564959f, 26.79954f, 26.97954f,
            27.181242f, 27.47763f, 27.74168f, 27.929441f, 28.117207f, 28.381254f,
            28.677637f, 28.879343f, 29.059345f, 29.293922f, 29.617298f, 29.890451f,
            30.113104f, 30.303034f, 30.472853f, 30.539684f, 30.516582f, 30.460762f,
            26.f, 26.178228f, 26.375526f, 26.61097f, 26.903624f, 27.213314f,
            27.424305f, 27.602734f, 27.823772f, 28.135519f, 28.3701f, 28.550098f,
            28.7518f, 29.04819f, 29.312237f, 29.5f, 29.687763f, 29.951813f,
            30.2482f, 30.449903f, 30.629902f, 30.864483f, 31.187859f, 31.461012f,
            31.683659f, 31.873592f, 32.043407f, 32.11024f, 32.087135f, 32.03132f,
            27.570559f, 27.748787f, 27.946087f, 28.181528f, 28.474184f, 28.783876f,
            28.994865f, 29.173294f, 29.39433f, 29.70608f, 29.940659f, 30.120655f,
            30.32236f, 30.618746f, 30.882797f, 31.070557f, 31.25832f, 31.522371f,
            31.818754f, 32.02046f, 32.20046f, 32.43504f, 32.758415f, 33.031567f,
            33.25422f, 33.44415f, 33.613964f, 33.680794f, 33.657696f, 33.60188f,
            29.636976f, 29.815207f, 30.0125f, 30.247944f, 30.5406f, 30.85029f,
            31.061283f, 31.239712f, 31.46075f, 31.7725f, 32.00708f, 32.187077f,
            32.38878f, 32.685165f, 32.949215f, 33.13698f, 33.32474f, 33.58879f,
            33.885178f, 34.086884f, 34.26688f, 34.501457f, 34.824837f, 35.09799f,
            35.320637f, 35.510574f, 35.68039f, 35.747215f, 35.724117f, 35.6683f,
            31.83344f, 32.011665f, 32.20897f, 32.444412f, 32.73707f, 33.046757f,
            33.257744f, 33.436176f, 33.657207f, 33.96896f, 34.203537f, 34.383537f,
            34.58524f, 34.88163f, 35.145676f, 35.33344f, 35.521206f, 35.785255f,
            36.081642f, 36.28334f, 36.46334f, 36.69792f, 37.021297f, 37.294453f,
            37.517097f, 37.707027f, 37.876846f, 37.94368f, 37.920578f, 37.864758f,
            33.253647f, 33.431873f, 33.62917f, 33.864613f, 34.15727f, 34.466957f,
            34.677948f, 34.856377f, 35.077415f, 35.38916f, 35.623745f, 35.803745f,
            36.005447f, 36.301834f, 36.565884f, 36.753647f, 36.941406f, 37.205456f,
            37.50184f, 37.703545f, 37.883545f, 38.118122f, 38.4415f, 38.714653f,
            38.9373f, 39.127235f, 39.297054f, 39.363884f, 39.340782f, 39.28496f,
            34.464783f, 34.64301f, 34.840305f, 35.075752f, 35.368404f, 35.6781f,
            35.889088f, 36.067516f, 36.28855f, 36.6003f, 36.834885f, 37.014877f,
            37.216583f, 37.51297f, 37.77702f, 37.964783f, 38.152546f, 38.416595f,
            38.71298f, 38.914684f, 39.094685f, 39.32926f, 39.652645f, 39.925793f,
            40.14844f, 40.338375f, 40.508194f, 40.575024f, 40.55192f, 40.496105f,
            36.058067f, 36.23629f, 36.43359f, 36.669033f, 36.961685f, 37.271378f,
            37.48237f, 37.6608f, 37.881836f, 38.19359f, 38.42817f, 38.608162f,
            38.809868f, 39.10625f, 39.3703f, 39.558064f, 39.74583f, 40.00988f,
            40.306267f, 40.50797f, 40.68797f, 40.92255f, 41.245926f, 41.519077f,
            41.741722f, 41.931652f, 42.101475f, 42.168304f, 42.145203f, 42.089386f,
            38.315002f, 38.493233f, 38.690533f, 38.925976f, 39.218628f, 39.52832f,
            39.739307f, 39.917736f, 40.138775f, 40.45052f, 40.685104f, 40.865097f,
            41.066803f, 41.36319f, 41.627243f, 41.815002f, 42.002766f, 42.26682f,
            42.5632f, 42.764908f, 42.944904f, 43.179485f, 43.50286f, 43.776016f,
            43.998665f, 44.188595f, 44.358418f, 44.425247f, 44.402145f, 44.34633f,
            40.22708f, 40.40531f, 40.602608f, 40.83805f, 41.130707f, 41.440395f,
            41.651382f, 41.82982f, 42.050854f, 42.3626f, 42.597183f, 42.77718f,
            42.97888f, 43.27527f, 43.53932f, 43.72708f, 43.914845f, 44.178894f,
            44.47528f, 44.676983f, 44.856983f, 45.09156f, 45.41494f, 45.68809f,
            45.91074f, 46.100674f, 46.270493f, 46.337322f, 46.31422f, 46.2584f,
            41.785618f, 41.963844f, 42.161144f, 42.396584f, 42.68924f, 42.998936f,
            43.209923f, 43.388355f, 43.609394f, 43.921143f, 44.15572f, 44.335716f,
            44.53742f, 44.833805f, 45.09786f, 45.285614f, 45.473377f, 45.737427f,
            46.033817f, 46.235523f, 46.415524f, 46.650105f, 46.973476f, 47.24663f,
            47.469276f, 47.65921f, 47.82903f, 47.895855f, 47.872753f, 47.81694f,
            43.11514f, 43.293365f, 43.490665f, 43.726105f, 44.018764f, 44.328457f,
            44.539444f, 44.717873f, 44.93891f, 45.25066f, 45.48524f, 45.665237f,
            45.86694f, 46.163326f, 46.427376f, 46.615143f, 46.802902f, 47.066956f,
            47.363342f, 47.56505f, 47.74505f, 47.979626f, 48.302998f, 48.576153f,
            48.798798f, 48.98873f, 49.158546f, 49.225376f, 49.202282f, 49.146458f,
            44.303867f, 44.482094f, 44.679394f, 44.914833f, 45.207493f, 45.51718f,
            45.72817f, 45.9066f, 46.12764f, 46.439384f, 46.673965f, 46.853966f,
            47.055668f, 47.352055f, 47.6161f, 47.803867f, 47.99163f, 48.25568f,
            48.552063f, 48.75377f, 48.933773f, 49.16835f, 49.491726f, 49.764877f,
            49.987526f, 50.17746f, 50.347275f, 50.4141f, 50.391006f, 50.335186f,
            44.771675f, 44.949905f, 45.1472f, 45.382645f, 45.6753f, 45.98499f,
            46.195976f, 46.374413f, 46.595448f, 46.907196f, 47.141773f, 47.321774f,
            47.523476f, 47.819862f, 48.08391f, 48.27168f, 48.459446f, 48.72349f,
            49.019882f, 49.22158f, 49.401585f, 49.63616f, 49.959538f, 50.232693f,
            50.455338f, 50.64527f, 50.81509f, 50.88192f, 50.858818f, 50.803f,
            44.609966f, 44.788193f, 44.985493f, 45.220936f, 45.51359f, 45.82328f,
            46.03427f, 46.2127f, 46.433743f, 46.74549f, 46.98007f, 47.160065f,
            47.36177f, 47.658157f, 47.922207f, 48.10997f, 48.297733f, 48.561783f,
            48.858166f, 49.059875f, 49.239872f, 49.47445f, 49.79783f, 50.07098f,
            50.293625f, 50.48356f, 50.653378f, 50.720203f, 50.6971f, 50.64128f,
            44.219246f, 44.397472f, 44.594772f, 44.83021f, 45.122868f, 45.43256f,
            45.643543f, 45.82198f, 46.04302f, 46.354763f, 46.589344f, 46.76934f,
            46.971046f, 47.267433f, 47.531483f, 47.719242f, 47.907005f, 48.17105f,
            48.467438f, 48.66914f, 48.849144f, 49.08372f, 49.4071f, 49.680256f,
            49.902905f, 50.092834f, 50.262653f, 50.329483f, 50.30638f, 50.25057f});

    // Target size passed as the op's second input.
    auto size = NDArrayFactory::create<int>({30, 30});
    sd::ops::resize_bicubic op;
    auto results = op.evaluate({&input, &size}, {}, {});

    // ASSERT_* returns from the enclosing function on failure. Running the checks
    // inside a void lambda keeps that early exit from skipping the delete below.
    [&]() {
        ASSERT_EQ(ND4J_STATUS_OK, results->status());

        NDArray* result = results->at(0);

//        result->printBuffer("Resized to 30x30");
//        expected.printBuffer("Expect for 30x30");
        ASSERT_TRUE(expected.isSameShape(result));
        ASSERT_TRUE(expected.equalsTo(result));
    }();

    delete results;
}
|
|
|
|
// resize_bicubic: a {2,5,4,3} double input filled via linspace(1) is resized to
// 10x8 and compared with a precomputed {2,10,8,3} float reference table.
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test2) {

    NDArray input = NDArrayFactory::create<double>('c', {2, 5, 4, 3});
    NDArray expected = NDArrayFactory::create<float>('c', {2, 10, 8, 3}, {
            1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, 4.000000f, 5.000000f, 6.000000f,
            5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f,
            10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f,
            7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f,
            11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f,
            15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f,
            16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f,
            20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f,
            19.000000f, 20.000000f, 21.000000f, 20.218750f, 21.218750f, 22.218750f, 22.000000f, 23.000000f, 24.000000f,
            23.500000f, 24.500000f, 25.500000f, 25.000000f, 26.000000f, 27.000000f, 26.781250f, 27.781250f, 28.781250f,
            28.000000f, 29.000000f, 30.000000f, 28.281250f, 29.281250f, 30.281250f, 25.000000f, 26.000000f, 27.000000f,
            26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f,
            31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f,
            34.281250f, 35.281250f, 36.281250f, 31.000000f, 32.000000f, 33.000000f, 32.218750f, 33.218750f, 34.218750f,
            34.000000f, 35.000000f, 36.000000f, 35.500000f, 36.500000f, 37.500000f, 37.000000f, 38.000000f, 39.000000f,
            38.781250f, 39.781250f, 40.781250f, 40.000000f, 41.000000f, 42.000000f, 40.281250f, 41.281250f, 42.281250f,
            37.000000f, 38.000000f, 39.000000f, 38.218750f, 39.218750f, 40.218750f, 40.000000f, 41.000000f, 42.000000f,
            41.500000f, 42.500000f, 43.500000f, 43.000000f, 44.000000f, 45.000000f, 44.781250f, 45.781250f, 46.781250f,
            46.000000f, 47.000000f, 48.000000f, 46.281250f, 47.281250f, 48.281250f, 44.125000f, 45.125000f, 46.125000f,
            45.343750f, 46.343750f, 47.343750f, 47.125000f, 48.125000f, 49.125000f, 48.625000f, 49.625000f, 50.625000f,
            50.125000f, 51.125000f, 52.125000f, 51.906250f, 52.906250f, 53.906250f, 53.125000f, 54.125000f, 55.125000f,
            53.406250f, 54.406250f, 55.406250f, 49.000000f, 50.000000f, 51.000000f, 50.218750f, 51.218750f, 52.218750f,
            52.000000f, 53.000000f, 54.000000f, 53.500000f, 54.500000f, 55.500000f, 55.000000f, 56.000000f, 57.000000f,
            56.781250f, 57.781250f, 58.781250f, 58.000000f, 59.000000f, 60.000000f, 58.281250f, 59.281250f, 60.281250f,
            50.125000f, 51.125000f, 52.125000f, 51.343750f, 52.343750f, 53.343750f, 53.125000f, 54.125000f, 55.125000f,
            54.625000f, 55.625000f, 56.625000f, 56.125000f, 57.125000f, 58.125000f, 57.906250f, 58.906250f, 59.906250f,
            59.125000f, 60.125000f, 61.125000f, 59.406250f, 60.406250f, 61.406250f, 61.000000f, 62.000000f, 63.000000f,
            62.218750f, 63.218750f, 64.218750f, 64.000000f, 65.000000f, 66.000000f, 65.500000f, 66.500000f, 67.500000f,
            67.000000f, 68.000000f, 69.000000f, 68.781250f, 69.781250f, 70.781250f, 70.000000f, 71.000000f, 72.000000f,
            70.281250f, 71.281250f, 72.281250f, 65.875000f, 66.875000f, 67.875000f, 67.093750f, 68.093750f, 69.093750f,
            68.875000f, 69.875000f, 70.875000f, 70.375000f, 71.375000f, 72.375000f, 71.875000f, 72.875000f, 73.875000f,
            73.656250f, 74.656250f, 75.656250f, 74.875000f, 75.875000f, 76.875000f, 75.156250f, 76.156250f, 77.156250f,
            73.000000f, 74.000000f, 75.000000f, 74.218750f, 75.218750f, 76.218750f, 76.000000f, 77.000000f, 78.000000f,
            77.500000f, 78.500000f, 79.500000f, 79.000000f, 80.000000f, 81.000000f, 80.781250f, 81.781250f, 82.781250f,
            82.000000f, 83.000000f, 84.000000f, 82.281250f, 83.281250f, 84.281250f, 79.000000f, 80.000000f, 81.000000f,
            80.218750f, 81.218750f, 82.218750f, 82.000000f, 83.000000f, 84.000000f, 83.500000f, 84.500000f, 85.500000f,
            85.000000f, 86.000000f, 87.000000f, 86.781250f, 87.781250f, 88.781250f, 88.000000f, 89.000000f, 90.000000f,
            88.281250f, 89.281250f, 90.281250f, 85.000000f, 86.000000f, 87.000000f, 86.218750f, 87.218750f, 88.218750f,
            88.000000f, 89.000000f, 90.000000f, 89.500000f, 90.500000f, 91.500000f, 91.000000f, 92.000000f, 93.000000f,
            92.781250f, 93.781250f, 94.781250f, 94.000000f, 95.000000f, 96.000000f, 94.281250f, 95.281250f, 96.281250f,
            91.000000f, 92.000000f, 93.000000f, 92.218750f, 93.218750f, 94.218750f, 94.000000f, 95.000000f, 96.000000f,
            95.500000f, 96.500000f, 97.500000f, 97.000000f, 98.000000f, 99.000000f, 98.781250f, 99.781250f, 100.781250f,
            100.000000f, 101.000000f, 102.000000f, 100.281250f, 101.281250f, 102.281250f, 97.000000f, 98.000000f,
            99.000000f, 98.218750f, 99.218750f, 100.218750f, 100.000000f, 101.000000f, 102.000000f, 101.500000f,
            102.500000f, 103.500000f, 103.000000f, 104.000000f, 105.000000f, 104.781250f, 105.781250f, 106.781250f,
            106.000000f, 107.000000f, 108.000000f, 106.281250f, 107.281250f, 108.281250f, 104.125000f, 105.125000f,
            106.125000f, 105.343750f, 106.343750f, 107.343750f, 107.125000f, 108.125000f, 109.125000f, 108.625000f,
            109.625000f, 110.625000f, 110.125000f, 111.125000f, 112.125000f, 111.906250f, 112.906250f, 113.906250f,
            113.125000f, 114.125000f, 115.125000f, 113.406250f, 114.406250f, 115.406250f, 109.000000f, 110.000000f,
            111.000000f, 110.218750f, 111.218750f, 112.218750f, 112.000000f, 113.000000f, 114.000000f, 113.500000f,
            114.500000f, 115.500000f, 115.000000f, 116.000000f, 117.000000f, 116.781250f, 117.781250f, 118.781250f,
            118.000000f, 119.000000f, 120.000000f, 118.281250f, 119.281250f, 120.281250f, 110.125000f, 111.125000f,
            112.125000f, 111.343750f, 112.343750f, 113.343750f, 113.125000f, 114.125000f, 115.125000f, 114.625000f,
            115.625000f, 116.625000f, 116.125000f, 117.125000f, 118.125000f, 117.906250f, 118.906250f, 119.906250f,
            119.125000f, 120.125000f, 121.125000f, 119.406250f, 120.406250f, 121.406250f
    });

    input.linspace(1);
    // Target size passed as the op's second input.
    auto size = NDArrayFactory::create<int>({10, 8});
    sd::ops::resize_bicubic op;
    auto results = op.evaluate({&input, &size}, {}, {});

    // ASSERT_* returns from the enclosing function on failure. Running the checks
    // inside a void lambda keeps that early exit from skipping the delete below.
    [&]() {
        ASSERT_EQ(ND4J_STATUS_OK, results->status());

        NDArray* result = results->at(0);

//        result->printBuffer("Resized to 10x8");
//        expected.printBuffer("Expect for 10x8");
        ASSERT_TRUE(expected.isSameShape(result));
        ASSERT_TRUE(expected.equalsTo(result));
    }();

    delete results;
}
|
|
|
|
|
|
|
|
// resize_bicubic: a {1,3,3,4} double input filled via linspace(1) is resized to
// 6x6 and compared with a precomputed {1,6,6,4} float reference table.
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test3) {

    NDArray input = NDArrayFactory::create<double>('c', {1, 3, 3, 4});
    NDArray expected = NDArrayFactory::create<float>('c', {1, 6, 6, 4}, {
            1.000000f, 2.000000f, 3.000000f, 4.000000f, 2.625000f, 3.625000f, 4.625000f, 5.625000f, 5.000000f,
            6.000000f, 7.000000f, 8.000000f, 7.375000f, 8.375000f, 9.375000f, 10.375000f, 9.000000f, 10.000000f,
            11.000000f, 12.000000f, 9.375000f, 10.375000f, 11.375000f, 12.375000f, 5.875000f, 6.875000f, 7.875000f,
            8.875000f, 7.500000f, 8.500000f, 9.500000f, 10.500000f, 9.875000f, 10.875000f, 11.875000f, 12.875000f,
            12.250000f, 13.250000f, 14.250000f, 15.250000f, 13.875000f, 14.875000f, 15.875000f, 16.875000f, 14.250000f,
            15.250000f, 16.250000f, 17.250000f, 13.000000f, 14.000000f, 15.000000f, 16.000000f, 14.625000f, 15.625000f,
            16.625000f, 17.625000f, 17.000000f, 18.000000f, 19.000000f, 20.000000f, 19.375000f, 20.375000f, 21.375000f,
            22.375000f, 21.000000f, 22.000000f, 23.000000f, 24.000000f, 21.375000f, 22.375000f, 23.375000f, 24.375000f,
            20.125000f, 21.125000f, 22.125000f, 23.125000f, 21.750000f, 22.750000f, 23.750000f, 24.750000f, 24.125000f,
            25.125000f, 26.125000f, 27.125000f, 26.500000f, 27.500000f, 28.500000f, 29.500000f, 28.125000f, 29.125000f,
            30.125000f, 31.125000f, 28.500000f, 29.500000f, 30.500000f, 31.500000f, 25.000000f, 26.000000f, 27.000000f,
            28.000000f, 26.625000f, 27.625000f, 28.625000f, 29.625000f, 29.000000f, 30.000000f, 31.000000f, 32.000000f,
            31.375000f, 32.375000f, 33.375000f, 34.375000f, 33.000000f, 34.000000f, 35.000000f, 36.000000f, 33.375000f,
            34.375000f, 35.375000f, 36.375000f, 26.125000f, 27.125000f, 28.125000f, 29.125000f, 27.750000f, 28.750000f,
            29.750000f, 30.750000f, 30.125000f, 31.125000f, 32.125000f, 33.125000f, 32.500000f, 33.500000f, 34.500000f,
            35.500000f, 34.125000f, 35.125000f, 36.125000f, 37.125000f, 34.500000f, 35.500000f, 36.500000f, 37.500000f
    });

    input.linspace(1);
    // Target size passed as the op's second input.
    auto size = NDArrayFactory::create<int>({6, 6});
    sd::ops::resize_bicubic op;
    auto results = op.evaluate({&input, &size}, {}, {});

    // ASSERT_* returns from the enclosing function on failure. Running the checks
    // inside a void lambda keeps that early exit from skipping the delete below.
    [&]() {
        ASSERT_EQ(ND4J_STATUS_OK, results->status());

        NDArray* result = results->at(0);

//        result->printBuffer("Resized to 6x6");
//        expected.printBuffer("Expect for 6x6");
        ASSERT_TRUE(expected.isSameShape(result));
        ASSERT_TRUE(expected.equalsTo(result));
    }();

    delete results;
}
|
|
|
|
|
|
|
|
// resize_bicubic: a {1,3,4,3} double input filled via linspace(1) is resized to
// 6x8 and compared with a precomputed {1,6,8,3} float reference table.
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test4) {

    NDArray input = NDArrayFactory::create<double>('c', {1, 3, 4, 3});
    NDArray expected = NDArrayFactory::create<float>('c', {1, 6, 8, 3}, {
            1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, 4.000000f, 5.000000f, 6.000000f,
            5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f,
            10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f,
            7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f,
            11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f,
            15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f,
            16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f,
            20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f,
            20.125000f, 21.125000f, 22.125000f, 21.343750f, 22.343750f, 23.343750f, 23.125000f, 24.125000f, 25.125000f,
            24.625000f, 25.625000f, 26.625000f, 26.125000f, 27.125000f, 28.125000f, 27.906250f, 28.906250f, 29.906250f,
            29.125000f, 30.125000f, 31.125000f, 29.406250f, 30.406250f, 31.406250f, 25.000000f, 26.000000f, 27.000000f,
            26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f,
            31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f,
            34.281250f, 35.281250f, 36.281250f, 26.125000f, 27.125000f, 28.125000f, 27.343750f, 28.343750f, 29.343750f,
            29.125000f, 30.125000f, 31.125000f, 30.625000f, 31.625000f, 32.625000f, 32.125000f, 33.125000f, 34.125000f,
            33.906250f, 34.906250f, 35.906250f, 35.125000f, 36.125000f, 37.125000f, 35.406250f, 36.406250f, 37.406250f
    });

    input.linspace(1);
    // Target size passed as the op's second input.
    auto size = NDArrayFactory::create<int>({6, 8});
    sd::ops::resize_bicubic op;
    auto results = op.evaluate({&input, &size}, {}, {});

    // ASSERT_* returns from the enclosing function on failure. Running the checks
    // inside a void lambda keeps that early exit from skipping the delete below.
    [&]() {
        ASSERT_EQ(ND4J_STATUS_OK, results->status());

        NDArray* result = results->at(0);

//        result->printBuffer("Resized to 6x8");
//        expected.printBuffer("Expect for 6x8");
        ASSERT_TRUE(expected.isSameShape(result));
        ASSERT_TRUE(expected.equalsTo(result));
    }();

    delete results;
}
|
|
|
|
|
|
|
|
// resize_bicubic: a {1, 4, 4, 3} input filled by linspace(1) is resized with a
// requested size of {8, 8}; the op output must match the precomputed
// {1, 8, 8, 3} reference values below.
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test5) {

    NDArray input = NDArrayFactory::create<double>('c', {1, 4, 4, 3});
    NDArray expected = NDArrayFactory::create<float>('c', {1, 8, 8, 3}, {
        1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, 4.000000f, 5.000000f, 6.000000f,
        5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f,
        10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f,
        7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f,
        11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f,
        15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f,
        16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f,
        20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f,
        19.000000f, 20.000000f, 21.000000f, 20.218750f, 21.218750f, 22.218750f, 22.000000f, 23.000000f, 24.000000f,
        23.500000f, 24.500000f, 25.500000f, 25.000000f, 26.000000f, 27.000000f, 26.781250f, 27.781250f, 28.781250f,
        28.000000f, 29.000000f, 30.000000f, 28.281250f, 29.281250f, 30.281250f, 25.000000f, 26.000000f, 27.000000f,
        26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f,
        31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f,
        34.281250f, 35.281250f, 36.281250f, 32.125000f, 33.125000f, 34.125000f, 33.343750f, 34.343750f, 35.343750f,
        35.125000f, 36.125000f, 37.125000f, 36.625000f, 37.625000f, 38.625000f, 38.125000f, 39.125000f, 40.125000f,
        39.906250f, 40.906250f, 41.906250f, 41.125000f, 42.125000f, 43.125000f, 41.406250f, 42.406250f, 43.406250f,
        37.000000f, 38.000000f, 39.000000f, 38.218750f, 39.218750f, 40.218750f, 40.000000f, 41.000000f, 42.000000f,
        41.500000f, 42.500000f, 43.500000f, 43.000000f, 44.000000f, 45.000000f, 44.781250f, 45.781250f, 46.781250f,
        46.000000f, 47.000000f, 48.000000f, 46.281250f, 47.281250f, 48.281250f, 38.125000f, 39.125000f, 40.125000f,
        39.343750f, 40.343750f, 41.343750f, 41.125000f, 42.125000f, 43.125000f, 42.625000f, 43.625000f, 44.625000f,
        44.125000f, 45.125000f, 46.125000f, 45.906250f, 46.906250f, 47.906250f, 47.125000f, 48.125000f, 49.125000f,
        47.406250f, 48.406250f, 49.406250f
    });
    input.linspace(1);
    auto size = NDArrayFactory::create<int>({8, 8});

    sd::ops::resize_bicubic op;
    auto results = op.evaluate({&input, &size}, {}, {});

    // Debugging aid: results->at(0)->printBuffer("Resized to 8x8");
    //                expected.printBuffer("Expect for 8x8");

    // Evaluate every check before asserting: a failing ASSERT_* returns from the
    // test body immediately, which would skip a trailing `delete results` and
    // leak the result set. The && chain keeps the original order of checks
    // (status, then shape, then values), each only run if the previous passed.
    const auto status = results->status();
    const bool sameShape = status == ND4J_STATUS_OK && expected.isSameShape(results->at(0));
    const bool sameValues = sameShape && expected.equalsTo(results->at(0));
    delete results;

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(sameShape);
    ASSERT_TRUE(sameValues);
}
|
|
|
|
|
|
|
|
// resize_bicubic on a rank-3 input: an explicit {7, 7, 1} float array is
// resized with a requested size of {30, 30}; the op output must match the
// precomputed {30, 30, 1} reference values below.
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) {

    NDArray input = NDArrayFactory::create<float>('c', {7, 7, 1}, {
        1.f, 2.1f, 3.15f, 4.2f, 5.15f, 6.1f, 7.f,
        8.f, 9.1f, 10.f, 11.f, 12.9f, 13.1f, 14.f,
        15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f,
        22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f,
        30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f,
        37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f,
        44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f
    });

    NDArray expected = NDArrayFactory::create<float>('c', {30, 30, 1}, {
        1.000000f, 1.197616f, 1.417436f, 1.677577f, 1.996158f, 2.328327f, 2.550918f, 2.736061f, 2.965541f,
        3.292965f, 3.544151f, 3.738035f, 3.948995f, 4.248106f, 4.507379f, 4.684374f, 4.857284f, 5.104302f,
        5.386991f, 5.581401f, 5.753962f, 5.974285f, 6.272836f, 6.520426f, 6.718899f, 6.887104f, 7.039068f,
        7.099216f, 7.078424f, 7.028189f, 2.247592f, 2.446947f, 2.669489f, 2.931238f, 3.248216f, 3.574534f,
        3.789310f, 3.965697f, 4.186417f, 4.504653f, 4.740569f, 4.921706f, 5.133866f, 5.459533f, 5.774461f,
        6.019787f, 6.254011f, 6.535633f, 6.809730f, 6.960779f, 7.074942f, 7.241601f, 7.509489f, 7.749949f,
        7.954571f, 8.131972f, 8.286526f, 8.346463f, 8.325745f, 8.275683f, 3.628684f, 3.830573f, 4.056959f,
        4.321157f, 4.636486f, 4.955650f, 5.160583f, 5.325847f, 5.535462f, 5.842160f, 6.058749f, 6.223753f,
        6.437597f, 6.797369f, 7.183604f, 7.516402f, 7.829034f, 8.154773f, 8.417635f, 8.512958f, 8.552100f,
        8.649708f, 8.877880f, 9.108794f, 9.320926f, 9.509781f, 9.667375f, 9.726940f, 9.706349f, 9.656599f,
        5.276778f, 5.480438f, 5.709702f, 5.975448f, 6.288551f, 6.600570f, 6.796207f, 6.951142f, 7.150400f,
        7.446143f, 7.644651f, 7.794562f, 8.009684f, 8.400473f, 8.851847f, 9.264690f, 9.649218f, 10.015648f,
        10.268647f, 10.313368f, 10.284327f, 10.319379f, 10.512033f, 10.734956f, 10.954604f, 11.154507f, 11.315369f,
        11.374779f, 11.354242f, 11.304622f, 7.325373f, 7.528484f, 7.757575f, 8.022221f, 8.331997f, 8.638187f,
        8.827649f, 8.976217f, 9.168955f, 9.457260f, 9.644237f, 9.784517f, 9.999621f, 10.407702f, 10.896234f,
        11.355122f, 11.781423f, 12.172186f, 12.420712f, 12.437449f, 12.370511f, 12.371386f, 12.545973f, 12.766424f,
        12.992249f, 13.200120f, 13.364252f, 13.424109f, 13.403420f, 13.353425f, 9.493208f, 9.692467f, 9.916944f,
        10.176801f, 10.482199f, 10.785470f, 10.974367f, 11.123442f, 11.316370f, 11.603645f, 11.790616f, 11.930889f,
        12.144082f, 12.546447f, 13.024898f, 13.472300f, 13.889232f, 14.276275f, 14.528972f, 14.555555f, 14.501450f,
        14.515459f, 14.700572f, 14.927055f, 15.156046f, 15.366046f, 15.532901f, 15.594008f, 15.572885f, 15.521847f,
        10.970133f, 11.163599f, 11.380694f, 11.633735f, 11.935032f, 12.238887f, 12.432540f, 12.588294f, 12.787534f,
        13.079956f, 13.277520f, 13.426631f, 13.636713f, 14.013844f, 14.441672f, 14.827978f, 15.191209f, 15.549808f,
        15.813430f, 15.881828f, 15.883522f, 15.950411f, 16.169330f, 16.407940f, 16.636436f, 16.842583f, 17.010887f,
        17.073630f, 17.051940f, 16.999537f, 12.219155f, 12.406129f, 12.614796f, 12.860335f, 13.157928f, 13.464224f,
        13.665207f, 13.830567f, 14.039036f, 14.339629f, 14.552863f, 14.715049f, 14.921564f, 15.264454f, 15.622843f,
        15.924977f, 16.213829f, 16.532364f, 16.809900f, 16.934835f, 17.012146f, 17.150164f, 17.413412f, 17.666712f,
        17.892765f, 18.092070f, 18.261044f, 18.325531f, 18.303238f, 18.249378f, 13.766397f, 13.947391f, 14.148263f,
        14.386917f, 14.681246f, 14.990087f, 15.198166f, 15.372728f, 15.590062f, 15.898583f, 16.126892f, 16.301655f,
        16.504870f, 16.815214f, 17.107498f, 17.329458f, 17.547403f, 17.827654f, 18.118288f, 18.296928f, 18.446100f,
        18.651634f, 18.956806f, 19.223820f, 19.447308f, 19.639887f, 19.809319f, 19.875397f, 19.852556f, 19.797365f,
        15.941937f, 16.118704f, 16.314133f, 16.547867f, 16.839561f, 17.149540f, 17.361883f, 17.542162f, 17.764957f,
        18.078188f, 18.315733f, 18.498205f, 18.699116f, 18.988684f, 19.238989f, 19.410137f, 19.583265f, 19.839512f,
        20.138780f, 20.351770f, 20.546844f, 20.795671f, 21.128067f, 21.404358f, 21.626736f, 21.815500f, 21.985610f,
        22.052843f, 22.029604f, 21.973448f, 17.535220f, 17.710770f, 17.904636f, 18.136950f, 18.427840f, 18.738056f,
        18.951529f, 19.133352f, 19.357613f, 19.672083f, 19.912102f, 20.096638f, 20.296894f, 20.580765f, 20.819603f,
        20.976887f, 21.137802f, 21.387535f, 21.689209f, 21.911621f, 22.119276f, 22.379990f, 22.719910f, 22.998823f,
        23.220970f, 23.408760f, 23.579110f, 23.646685f, 23.623325f, 23.566887f, 18.746353f, 18.922657f, 19.117487f,
        19.350685f, 19.642070f, 19.952137f, 20.164913f, 20.345781f, 20.569134f, 20.882840f, 21.121330f, 21.304590f,
        21.505253f, 21.792645f, 22.038572f, 22.204426f, 22.372890f, 22.626648f, 22.926834f, 23.143423f, 23.343302f,
        23.596668f, 23.931936f, 24.209232f, 24.431519f, 24.619913f, 24.790110f, 24.857473f, 24.834190f, 24.777927f,
        20.166560f, 20.344206f, 20.540766f, 20.775532f, 21.067804f, 21.377607f, 21.589132f, 21.768297f, 21.990030f,
        22.302366f, 22.538124f, 22.719105f, 22.920494f, 23.214176f, 23.472767f, 23.653934f, 23.835890f, 24.096842f,
        24.394371f, 24.600555f, 24.786541f, 25.026773f, 25.353731f, 25.628130f, 25.850672f, 26.040140f, 26.210072f,
        26.277063f, 26.253906f, 26.197956f, 22.363024f, 22.541250f, 22.738552f, 22.973991f, 23.266647f, 23.576340f,
        23.787327f, 23.965760f, 24.186796f, 24.498543f, 24.733124f, 24.913122f, 25.114826f, 25.411213f, 25.675262f,
        25.863028f, 26.050789f, 26.314838f, 26.611223f, 26.812925f, 26.992926f, 27.227505f, 27.550882f, 27.824034f,
        28.046684f, 28.236614f, 28.406433f, 28.473265f, 28.450163f, 28.394344f, 24.429443f, 24.607670f, 24.804970f,
        25.040410f, 25.333065f, 25.642756f, 25.853743f, 26.032173f, 26.253210f, 26.564959f, 26.799540f, 26.979540f,
        27.181242f, 27.477630f, 27.741680f, 27.929441f, 28.117207f, 28.381254f, 28.677637f, 28.879343f, 29.059345f,
        29.293922f, 29.617298f, 29.890451f, 30.113104f, 30.303034f, 30.472853f, 30.539684f, 30.516582f, 30.460762f,
        26.000000f, 26.178228f, 26.375526f, 26.610970f, 26.903624f, 27.213314f, 27.424305f, 27.602734f, 27.823772f,
        28.135519f, 28.370100f, 28.550098f, 28.751800f, 29.048190f, 29.312237f, 29.500000f, 29.687763f, 29.951813f,
        30.248200f, 30.449903f, 30.629902f, 30.864483f, 31.187859f, 31.461012f, 31.683659f, 31.873592f, 32.043407f,
        32.110240f, 32.087135f, 32.031320f, 27.570559f, 27.748787f, 27.946087f, 28.181528f, 28.474184f, 28.783876f,
        28.994865f, 29.173294f, 29.394330f, 29.706080f, 29.940659f, 30.120655f, 30.322360f, 30.618746f, 30.882797f,
        31.070557f, 31.258320f, 31.522371f, 31.818754f, 32.020460f, 32.200460f, 32.435040f, 32.758415f, 33.031567f,
        33.254220f, 33.444150f, 33.613964f, 33.680794f, 33.657696f, 33.601880f, 29.636976f, 29.815207f, 30.012500f,
        30.247944f, 30.540600f, 30.850290f, 31.061283f, 31.239712f, 31.460750f, 31.772500f, 32.007080f, 32.187077f,
        32.388780f, 32.685165f, 32.949215f, 33.136980f, 33.324740f, 33.588790f, 33.885178f, 34.086884f, 34.266880f,
        34.501457f, 34.824837f, 35.097990f, 35.320637f, 35.510574f, 35.680390f, 35.747215f, 35.724117f, 35.668300f,
        31.833440f, 32.011665f, 32.208970f, 32.444412f, 32.737070f, 33.046757f, 33.257744f, 33.436176f, 33.657207f,
        33.968960f, 34.203537f, 34.383537f, 34.585240f, 34.881630f, 35.145676f, 35.333440f, 35.521206f, 35.785255f,
        36.081642f, 36.283340f, 36.463340f, 36.697920f, 37.021297f, 37.294453f, 37.517097f, 37.707027f, 37.876846f,
        37.943680f, 37.920578f, 37.864758f, 33.253647f, 33.431873f, 33.629170f, 33.864613f, 34.157270f, 34.466957f,
        34.677948f, 34.856377f, 35.077415f, 35.389160f, 35.623745f, 35.803745f, 36.005447f, 36.301834f, 36.565884f,
        36.753647f, 36.941406f, 37.205456f, 37.501840f, 37.703545f, 37.883545f, 38.118122f, 38.441500f, 38.714653f,
        38.937300f, 39.127235f, 39.297054f, 39.363884f, 39.340782f, 39.284960f, 34.464783f, 34.643010f, 34.840305f,
        35.075752f, 35.368404f, 35.678100f, 35.889088f, 36.067516f, 36.288550f, 36.600300f, 36.834885f, 37.014877f,
        37.216583f, 37.512970f, 37.777020f, 37.964783f, 38.152546f, 38.416595f, 38.712980f, 38.914684f, 39.094685f,
        39.329260f, 39.652645f, 39.925793f, 40.148440f, 40.338375f, 40.508194f, 40.575024f, 40.551920f, 40.496105f,
        36.058067f, 36.236290f, 36.433590f, 36.669033f, 36.961685f, 37.271378f, 37.482370f, 37.660800f, 37.881836f,
        38.193590f, 38.428170f, 38.608162f, 38.809868f, 39.106250f, 39.370300f, 39.558064f, 39.745830f, 40.009880f,
        40.306267f, 40.507970f, 40.687970f, 40.922550f, 41.245926f, 41.519077f, 41.741722f, 41.931652f, 42.101475f,
        42.168304f, 42.145203f, 42.089386f, 38.315002f, 38.493233f, 38.690533f, 38.925976f, 39.218628f, 39.528320f,
        39.739307f, 39.917736f, 40.138775f, 40.450520f, 40.685104f, 40.865097f, 41.066803f, 41.363190f, 41.627243f,
        41.815002f, 42.002766f, 42.266820f, 42.563200f, 42.764908f, 42.944904f, 43.179485f, 43.502860f, 43.776016f,
        43.998665f, 44.188595f, 44.358418f, 44.425247f, 44.402145f, 44.346330f, 40.227080f, 40.405310f, 40.602608f,
        40.838050f, 41.130707f, 41.440395f, 41.651382f, 41.829820f, 42.050854f, 42.362600f, 42.597183f, 42.777180f,
        42.978880f, 43.275270f, 43.539320f, 43.727080f, 43.914845f, 44.178894f, 44.475280f, 44.676983f, 44.856983f,
        45.091560f, 45.414940f, 45.688090f, 45.910740f, 46.100674f, 46.270493f, 46.337322f, 46.314220f, 46.258400f,
        41.785618f, 41.963844f, 42.161144f, 42.396584f, 42.689240f, 42.998936f, 43.209923f, 43.388355f, 43.609394f,
        43.921143f, 44.155720f, 44.335716f, 44.537420f, 44.833805f, 45.097860f, 45.285614f, 45.473377f, 45.737427f,
        46.033817f, 46.235523f, 46.415524f, 46.650105f, 46.973476f, 47.246630f, 47.469276f, 47.659210f, 47.829030f,
        47.895855f, 47.872753f, 47.816940f, 43.115140f, 43.293365f, 43.490665f, 43.726105f, 44.018764f, 44.328457f,
        44.539444f, 44.717873f, 44.938910f, 45.250660f, 45.485240f, 45.665237f, 45.866940f, 46.163326f, 46.427376f,
        46.615143f, 46.802902f, 47.066956f, 47.363342f, 47.565050f, 47.745050f, 47.979626f, 48.302998f, 48.576153f,
        48.798798f, 48.988730f, 49.158546f, 49.225376f, 49.202282f, 49.146458f, 44.303867f, 44.482094f, 44.679394f,
        44.914833f, 45.207493f, 45.517180f, 45.728170f, 45.906600f, 46.127640f, 46.439384f, 46.673965f, 46.853966f,
        47.055668f, 47.352055f, 47.616100f, 47.803867f, 47.991630f, 48.255680f, 48.552063f, 48.753770f, 48.933773f,
        49.168350f, 49.491726f, 49.764877f, 49.987526f, 50.177460f, 50.347275f, 50.414100f, 50.391006f, 50.335186f,
        44.771675f, 44.949905f, 45.147200f, 45.382645f, 45.675300f, 45.984990f, 46.195976f, 46.374413f, 46.595448f,
        46.907196f, 47.141773f, 47.321774f, 47.523476f, 47.819862f, 48.083910f, 48.271680f, 48.459446f, 48.723490f,
        49.019882f, 49.221580f, 49.401585f, 49.636160f, 49.959538f, 50.232693f, 50.455338f, 50.645270f, 50.815090f,
        50.881920f, 50.858818f, 50.803000f, 44.609966f, 44.788193f, 44.985493f, 45.220936f, 45.513590f, 45.823280f,
        46.034270f, 46.212700f, 46.433743f, 46.745490f, 46.980070f, 47.160065f, 47.361770f, 47.658157f, 47.922207f,
        48.109970f, 48.297733f, 48.561783f, 48.858166f, 49.059875f, 49.239872f, 49.474450f, 49.797830f, 50.070980f,
        50.293625f, 50.483560f, 50.653378f, 50.720203f, 50.697100f, 50.641280f, 44.219246f, 44.397472f, 44.594772f,
        44.830210f, 45.122868f, 45.432560f, 45.643543f, 45.821980f, 46.043020f, 46.354763f, 46.589344f, 46.769340f,
        46.971046f, 47.267433f, 47.531483f, 47.719242f, 47.907005f, 48.171050f, 48.467438f, 48.669140f, 48.849144f,
        49.083720f, 49.407100f, 49.680256f, 49.902905f, 50.092834f, 50.262653f, 50.329483f, 50.306380f, 50.250570f
    });

    auto size = NDArrayFactory::create<int>({30, 30});

    sd::ops::resize_bicubic op;
    auto results = op.evaluate({&input, &size}, {}, {});

    // Debugging aid: results->at(0)->printBuffer("Resized to 30x30");
    //                expected.printBuffer("Expect for 30x30");

    // Evaluate every check before asserting: a failing ASSERT_* returns from the
    // test body immediately, which would skip a trailing `delete results` and
    // leak the result set. The && chain keeps the original order of checks
    // (status, then shape, then values), each only run if the previous passed.
    const auto status = results->status();
    const bool sameShape = status == ND4J_STATUS_OK && expected.isSameShape(results->at(0));
    const bool sameValues = sameShape && expected.equalsTo(results->at(0));
    delete results;

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(sameShape);
    ASSERT_TRUE(sameValues);
}
|
|
|
|
|
|
|
|
// resize_bicubic with default op arguments: an explicit {2, 5, 5, 1} double
// input (values given to four decimals) is resized with a requested size of
// {9, 9}; the op output must match the precomputed {2, 9, 9, 1} reference.
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test7) {

    NDArray input = NDArrayFactory::create<double>('c', {2, 5, 5, 1}, {
        0.2303, 0.7950, 0.8171, 0.0451, 0.3690, 0.6846, 0.2727, 0.2770, 0.2381, 0.9511,
        0.4116, 0.3997, 0.4075, 0.6275, 0.8018, 0.0678, 0.6221, 0.2982, 0.1524, 0.2613,
        0.7425, 0.6036, 0.7926, 0.5838, 0.1361, 0.4154, 0.3634, 0.3741, 0.2088, 0.2989,
        0.3982, 0.5618, 0.7266, 0.1089, 0.2922, 0.3306, 0.2869, 0.6638, 0.3091, 0.9312,
        0.0240, 0.2893, 0.5632, 0.9625, 0.4189, 0.3854, 0.2743, 0.6754, 0.8820, 0.8699});

    NDArray expected = NDArrayFactory::create<float>('c', {2, 9, 9, 1}, {
        0.2303f, 0.54569f, 0.840649f, 0.92725444f, 0.65660673f,
        0.16641647f, 0.06117659f, 0.33279106f, 0.4023279f, 0.5139505f,
        0.49821317f, 0.4906872f, 0.537642f, 0.4070102f, 0.13030615f,
        0.258801f, 0.65352744f, 0.773368f, 0.69225276f, 0.44177493f,
        0.21910316f, 0.22368976f, 0.24221404f, 0.21399781f, 0.5114972f,
        0.9169859f, 1.0511527f, 0.5608501f, 0.41315168f, 0.2913824f,
        0.2966933f, 0.38585684f, 0.48849702f, 0.71013063f, 0.9086001f,
        0.9794303f, 0.29625386f, 0.39427578f, 0.45971435f, 0.39693952f,
        0.40860707f, 0.51061106f, 0.6181093f, 0.67309624f, 0.69564015f,
        0.06012487f, 0.3863805f, 0.58993465f, 0.40679216f, 0.22607432f,
        0.20093678f, 0.25901243f, 0.3615362f, 0.39371052f, 0.24176767f,
        0.4868709f, 0.650651f, 0.5493148f, 0.3825456f, 0.27788478f,
        0.18927254f, 0.16692996f, 0.15432167f, 0.677519f, 0.6236242f,
        0.61700624f, 0.7214321f, 0.7307374f, 0.6251454f, 0.3924176f,
        0.17802659f, 0.10231908f, 0.81192374f, 0.66878575f, 0.6118803f,
        0.7797006f, 0.8396968f, 0.72889954f, 0.44547448f, 0.16794783f,
        0.07125802f, 0.4154f, 0.38504714f, 0.3623221f, 0.3862173f,
        0.3397379f, 0.23285517f, 0.21876639f, 0.2892362f, 0.30817088f,
        0.41268015f, 0.45587808f, 0.51991886f, 0.60977113f, 0.49489656f,
        0.21313031f, 0.11297428f, 0.2167207f, 0.23940037f, 0.39337245f,
        0.46112412f, 0.583034f, 0.76207364f, 0.6326203f, 0.22189438f,
        0.12071565f, 0.3275853f, 0.3794855f, 0.38497013f, 0.35049653f,
        0.41895086f, 0.671095f, 0.62119365f, 0.22362521f, 0.30189657f,
        0.72530353f, 0.85048175f, 0.2524255f, 0.2182264f, 0.2964637f,
        0.5361996f, 0.6255393f, 0.46424767f, 0.5741281f, 0.8408146f,
        0.92403257f, 0.04648584f, 0.14959256f, 0.32215607f, 0.46194845f,
        0.6642166f, 0.83560026f, 0.7663391f, 0.5284251f, 0.4573109f,
        0.10357999f, 0.17442937f, 0.32116935f, 0.45530772f, 0.7163773f,
        0.9856574f, 0.8976148f, 0.5538923f, 0.45173654f, 0.34958175f,
        0.2680429f, 0.30470955f, 0.51233786f, 0.75128907f, 0.86736864f,
        0.8982046f, 0.83254474f, 0.8168574f, 0.4225865f, 0.2956836f,
        0.29948136f, 0.5276342f, 0.76461166f, 0.8442875f, 0.907862f,
        0.9139262f, 0.92068815f
    });
    auto size = NDArrayFactory::create<int>({9, 9});

    sd::ops::resize_bicubic op;
    auto results = op.evaluate({&input, &size}, {}, {});

    // Debugging aid: results->at(0)->printBuffer("Resized to 9x9");
    //                expected.printBuffer("Expect for 9x9");

    // Evaluate every check before asserting: a failing ASSERT_* returns from the
    // test body immediately, which would skip a trailing `delete results` and
    // leak the result set. The && chain keeps the original order of checks
    // (status, then shape, then values), each only run if the previous passed.
    const auto status = results->status();
    const bool sameShape = status == ND4J_STATUS_OK && expected.isSameShape(results->at(0));
    const bool sameValues = sameShape && expected.equalsTo(results->at(0));
    delete results;

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(sameShape);
    ASSERT_TRUE(sameValues);
}
|
|
|
|
|
|
|
|
// resize_bicubic with explicit boolean op arguments {true, false}: an explicit
// {2, 5, 5, 1} double input is resized with a requested size of {9, 9}; the op
// output must match the precomputed {2, 9, 9, 1} reference in `testData`.
// NOTE(review): the two booleans are presumably align-corners / half-pixel-centers
// flags - confirm against the resize_bicubic op declaration.
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) {

    NDArray input = NDArrayFactory::create<double>('c', {2, 5, 5, 1}, {
        0.23028551377579154, 0.7949972231516509, 0.8171307820461517, 0.04507309923418412, 0.3689673597428338,
        0.6845757584903018, 0.27268547668219667, 0.2770196372806053, 0.2381478370531429, 0.9511201914609859,
        0.41160882670429033, 0.3997152563642703, 0.4074505147711718, 0.6274595060113246, 0.8017922711300232,
        0.06782045852179475, 0.6220772280691722, 0.2982335327629251, 0.1523603480424196, 0.2612986044295986,
        0.7424762244324299, 0.6036156464824591, 0.7926371071102005, 0.5838270656432538, 0.13607200219168547,
        0.4154002170215956, 0.36340617544852116, 0.37405031188276827, 0.20880251686544882, 0.298919946410666,
        0.39820758164277126, 0.5617728968896589, 0.72660225993937, 0.10888245916813699, 0.29215797784445496,
        0.3305531351746034, 0.28693451964931715, 0.6637635348315494, 0.30913418229827583, 0.9312186188801752,
        0.0239594182399363, 0.2892942758780874, 0.5631691110629038, 0.9625499752246309, 0.4189439089689968,
        0.3854304088214935, 0.27426304203925045, 0.6754051704648238, 0.8820362490795286, 0.8699337744328859});

    auto testData = NDArrayFactory::create<float>('c', {2,9,9,1}, {
        0.230286514f, 0.510566354f, 0.794997215f, 0.931386113f, 0.817130804f, 0.402811885f, 0.045073099f, 0.134639814f, 0.368967354f,
        0.483021289f, 0.501266003f, 0.521932304f, 0.572325349f, 0.534847379f, 0.267853439f, 0.105112493f, 0.349290252f, 0.674043298f,
        0.684575737f, 0.478224277f, 0.272685468f, 0.239882097f, 0.27701965f, 0.191148892f, 0.23814784f, 0.590989769f, 0.951120198f,
        0.622912169f, 0.441326082f, 0.266387194f, 0.232538164f, 0.301838756f, 0.356378645f, 0.495445013f, 0.756725252f, 0.981704295f,
        0.411608815f, 0.40493685f, 0.399715245f, 0.381842017f, 0.407450527f, 0.501836538f, 0.627459526f, 0.735251725f, 0.801792264f,
        0.150875032f, 0.357000858f, 0.524536073f, 0.450354964f, 0.318719596f, 0.319606483f, 0.385957927f, 0.46392554f, 0.529285908f,
        0.06782046f, 0.375309169f, 0.622077227f, 0.525792599f, 0.298233539f, 0.184723631f, 0.15236035f, 0.193153858f, 0.261298597f,

        0.372918189f, 0.512539625f, 0.63369292f, 0.628733814f, 0.535196245f, 0.436597466f, 0.323553175f, 0.215942055f, 0.148014024f,
        0.742476225f, 0.655325174f, 0.603615642f, 0.704684138f, 0.79263711f, 0.747929871f, 0.583827078f, 0.340373576f, 0.136071995f,
        0.415400207f, 0.388405323f, 0.363406181f, 0.379345775f, 0.374050319f, 0.28397581f, 0.208802521f, 0.238369256f, 0.298919946f,
        0.413146496f, 0.444389015f, 0.488355637f, 0.568351328f, 0.556217432f, 0.345546633f, 0.140068889f, 0.148834035f, 0.23562704f,
        0.398207575f, 0.464537472f, 0.561772883f, 0.717433035f, 0.726602256f, 0.416013002f, 0.108882457f, 0.142608985f, 0.292157978f,
        0.391511708f, 0.389470309f, 0.442729384f, 0.651181757f, 0.737665415f, 0.41685915f, 0.138383076f, 0.342548877f, 0.659080088f,

        0.330553144f, 0.273416102f, 0.286934525f, 0.50450629f, 0.663763523f, 0.463456154f, 0.309134185f, 0.586929917f, 0.931218624f,
        0.137025774f, 0.169145152f, 0.263757467f, 0.436182201f, 0.597053051f, 0.657990932f, 0.662163854f, 0.68354249f, 0.692712903f,
        0.023959421f, 0.130951077f, 0.289294273f, 0.413664877f, 0.563169122f, 0.839498401f, 0.962549984f, 0.728188932f, 0.418943912f,
        0.175951749f, 0.198239252f, 0.281999886f, 0.420836329f, 0.609856486f, 0.863734365f, 0.983550847f, 0.825015843f, 0.596413136f,
        0.385430396f, 0.292239636f, 0.274263054f, 0.445040524f, 0.675405145f, 0.817462444f, 0.882036269f, 0.895356655f, 0.869933784f
    });

    auto size = NDArrayFactory::create<int>({9, 9});

    sd::ops::resize_bicubic op;
    auto results = op.evaluate({&input, &size}, {}, {}, {true, false});

    // Debugging aid: results->at(0)->printBuffer("Resized to 9x9");
    //                testData.printBuffer("Expect for 9x9");

    // Evaluate every check before asserting: a failing ASSERT_* returns from the
    // test body immediately, which would skip a trailing `delete results` and
    // leak the result set. The && chain keeps the original order of checks
    // (status, then shape, then values), each only run if the previous passed.
    const auto status = results->status();
    const bool sameShape = status == ND4J_STATUS_OK && testData.isSameShape(results->at(0));
    const bool sameValues = sameShape && testData.equalsTo(results->at(0));
    delete results;

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(sameShape);
    ASSERT_TRUE(sameValues);
}
|
|
|
|
|
|
|
|
// resize_area: a {1, 3, 3, 4} double input filled by linspace(1) is resized
// with a requested size of {6, 6}. In the {1, 6, 6, 4} reference below every
// group of four input values appears twice in a row and every row of such
// groups appears twice.
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test1) {

    NDArray input = NDArrayFactory::create<double>('c', {1, 3, 3, 4});
    NDArray expected = NDArrayFactory::create<float>('c', {1, 6, 6, 4}, {
        1.f, 2.f, 3.f, 4.f,
        1.f, 2.f, 3.f, 4.f,
        5.f, 6.f, 7.f, 8.f,
        5.f, 6.f, 7.f, 8.f,
        9.f, 10.f, 11.f, 12.f,
        9.f, 10.f, 11.f, 12.f,

        1.f, 2.f, 3.f, 4.f,
        1.f, 2.f, 3.f, 4.f,
        5.f, 6.f, 7.f, 8.f,
        5.f, 6.f, 7.f, 8.f,
        9.f, 10.f, 11.f, 12.f,
        9.f, 10.f, 11.f, 12.f,

        13.f, 14.f, 15.f, 16.f,
        13.f, 14.f, 15.f, 16.f,
        17.f, 18.f, 19.f, 20.f,
        17.f, 18.f, 19.f, 20.f,
        21.f, 22.f, 23.f, 24.f,
        21.f, 22.f, 23.f, 24.f,

        13.f, 14.f, 15.f, 16.f,
        13.f, 14.f, 15.f, 16.f,
        17.f, 18.f, 19.f, 20.f,
        17.f, 18.f, 19.f, 20.f,
        21.f, 22.f, 23.f, 24.f,
        21.f, 22.f, 23.f, 24.f,

        25.f, 26.f, 27.f, 28.f,
        25.f, 26.f, 27.f, 28.f,
        29.f, 30.f, 31.f, 32.f,
        29.f, 30.f, 31.f, 32.f,
        33.f, 34.f, 35.f, 36.f,
        33.f, 34.f, 35.f, 36.f,

        25.f, 26.f, 27.f, 28.f,
        25.f, 26.f, 27.f, 28.f,
        29.f, 30.f, 31.f, 32.f,
        29.f, 30.f, 31.f, 32.f,
        33.f, 34.f, 35.f, 36.f,
        33.f, 34.f, 35.f, 36.f });
    input.linspace(1);
    auto size = NDArrayFactory::create<int>({6, 6});

    sd::ops::resize_area op;
    auto results = op.evaluate({&input, &size}, {}, {});

    // Debugging aid: results->at(0)->printBuffer("Area Resized to 6x6");
    //                expected.printBuffer("Area Expect for 6x6");

    // Evaluate every check before asserting: a failing ASSERT_* returns from the
    // test body immediately, which would skip a trailing `delete results` and
    // leak the result set. The && chain keeps the original order of checks
    // (status, then shape, then values), each only run if the previous passed.
    const auto status = results->status();
    const bool sameShape = status == ND4J_STATUS_OK && expected.isSameShape(results->at(0));
    const bool sameValues = sameShape && expected.equalsTo(results->at(0));
    delete results;

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(sameShape);
    ASSERT_TRUE(sameValues);
}
|
|
|
|
|
|
|
|
// resize_area: a {1, 3, 3, 1} float input filled by linspace(1) is resized
// with a requested size of {6, 6}. In the {1, 6, 6, 1} reference below every
// input value appears as a 2x2 block.
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test2) {

    NDArray input = NDArrayFactory::create<float>('c', {1, 3, 3, 1});
    NDArray expected = NDArrayFactory::create<float>('c', {1, 6, 6, 1}, {
        1.f, 1.f, 2.f, 2.f, 3.f, 3.f,
        1.f, 1.f, 2.f, 2.f, 3.f, 3.f,
        4.f, 4.f, 5.f, 5.f, 6.f, 6.f,
        4.f, 4.f, 5.f, 5.f, 6.f, 6.f,
        7.f, 7.f, 8.f, 8.f, 9.f, 9.f,
        7.f, 7.f, 8.f, 8.f, 9.f, 9.f
    });
    input.linspace(1);
    auto size = NDArrayFactory::create<int>({6, 6});

    sd::ops::resize_area op;
    auto results = op.evaluate({&input, &size}, {}, {});

    // Debugging aid: results->at(0)->printBuffer("Area Resized to 6x6");
    //                expected.printBuffer("Area Expect for 6x6");

    // Evaluate every check before asserting: a failing ASSERT_* returns from the
    // test body immediately, which would skip a trailing `delete results` and
    // leak the result set. The && chain keeps the original order of checks
    // (status, then shape, then values), each only run if the previous passed.
    const auto status = results->status();
    const bool sameShape = status == ND4J_STATUS_OK && expected.isSameShape(results->at(0));
    const bool sameValues = sameShape && expected.equalsTo(results->at(0));
    delete results;

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(sameShape);
    ASSERT_TRUE(sameValues);
}
|
|
|
|
|
|
|
|
|
|
|
|
// resize_area: a {1, 3, 3, 3} float input filled by linspace(1) is resized
// with a requested size of {6, 6}. In the {1, 6, 6, 3} reference below every
// triple of input values appears twice in a row and every row of triples
// appears twice.
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test3) {

    NDArray input = NDArrayFactory::create<float>('c', {1, 3, 3, 3});
    NDArray expected = NDArrayFactory::create<float>('c', {1, 6, 6, 3}, {
        1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
        1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
        10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
        10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
        19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f,
        19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f
    });
    input.linspace(1);
    auto size = NDArrayFactory::create<int>({6, 6});

    sd::ops::resize_area op;
    auto results = op.evaluate({&input, &size}, {}, {});

    // Debugging aid: results->at(0)->printBuffer("Area Resized to 6x6");
    //                expected.printBuffer("Area Expect for 6x6");

    // Evaluate every check before asserting: a failing ASSERT_* returns from the
    // test body immediately, which would skip a trailing `delete results` and
    // leak the result set. The && chain keeps the original order of checks
    // (status, then shape, then values), each only run if the previous passed.
    const auto status = results->status();
    const bool sameShape = status == ND4J_STATUS_OK && expected.isSameShape(results->at(0));
    const bool sameValues = sameShape && expected.equalsTo(results->at(0));
    delete results;

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(sameShape);
    ASSERT_TRUE(sameValues);
}
|
|
|
|
|
|
|
|
// resize_area on a two-element leading dimension: a {2, 3, 3, 3} float input
// holding the sequence 1..27 twice is resized with a requested size of {6, 6}.
// The {2, 6, 6, 3} reference below repeats the same 6x6x3 pattern for both
// halves of the input.
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test4) {

    // The input is spelled out explicitly (1..27 repeated twice) rather than
    // generated with linspace.
    NDArray input = NDArrayFactory::create<float>('c', {2, 3, 3, 3}, {
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27
    });

    NDArray expected = NDArrayFactory::create<float>('c', {2, 6, 6, 3}, {
        1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
        1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
        10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
        10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
        19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f,
        19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f,

        1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
        1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
        10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
        10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
        19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f,
        19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f
    });
    auto size = NDArrayFactory::create<int>({6, 6});

    sd::ops::resize_area op;
    auto results = op.evaluate({&input, &size}, {}, {});

    // Debugging aid: results->at(0)->printBuffer("Area Resized to 6x6");
    //                expected.printBuffer("Area Expect for 6x6");

    // Evaluate every check before asserting: a failing ASSERT_* returns from the
    // test body immediately, which would skip a trailing `delete results` and
    // leak the result set. The && chain keeps the original order of checks
    // (status, then shape, then values), each only run if the previous passed.
    const auto status = results->status();
    const bool sameShape = status == ND4J_STATUS_OK && expected.isSameShape(results->at(0));
    const bool sameValues = sameShape && expected.equalsTo(results->at(0));
    delete results;

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(sameShape);
    ASSERT_TRUE(sameValues);
}
|
|
|
|
|
|
|
|
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test5) {
    // Two identical 3x3 images with 3 channels; each image holds the values 1..27.
    auto source = NDArrayFactory::create<int>('c', {2, 3, 3, 3}, {
         1,  2,  3,    4,  5,  6,    7,  8,  9,
        10, 11, 12,   13, 14, 15,   16, 17, 18,
        19, 20, 21,   22, 23, 24,   25, 26, 27,

         1,  2,  3,    4,  5,  6,    7,  8,  9,
        10, 11, 12,   13, 14, 15,   16, 17, 18,
        19, 20, 21,   22, 23, 24,   25, 26, 27
    });

    // Reference output: every source pixel shows up as a 2x2 block in the 6x6 result.
    auto reference = NDArrayFactory::create<float>('c', {2, 6, 6, 3}, {
        // first image
         1.f,  2.f,  3.f,    1.f,  2.f,  3.f,    4.f,  5.f,  6.f,    4.f,  5.f,  6.f,    7.f,  8.f,  9.f,    7.f,  8.f,  9.f,
         1.f,  2.f,  3.f,    1.f,  2.f,  3.f,    4.f,  5.f,  6.f,    4.f,  5.f,  6.f,    7.f,  8.f,  9.f,    7.f,  8.f,  9.f,
        10.f, 11.f, 12.f,   10.f, 11.f, 12.f,   13.f, 14.f, 15.f,   13.f, 14.f, 15.f,   16.f, 17.f, 18.f,   16.f, 17.f, 18.f,
        10.f, 11.f, 12.f,   10.f, 11.f, 12.f,   13.f, 14.f, 15.f,   13.f, 14.f, 15.f,   16.f, 17.f, 18.f,   16.f, 17.f, 18.f,
        19.f, 20.f, 21.f,   19.f, 20.f, 21.f,   22.f, 23.f, 24.f,   22.f, 23.f, 24.f,   25.f, 26.f, 27.f,   25.f, 26.f, 27.f,
        19.f, 20.f, 21.f,   19.f, 20.f, 21.f,   22.f, 23.f, 24.f,   22.f, 23.f, 24.f,   25.f, 26.f, 27.f,   25.f, 26.f, 27.f,

        // second image
         1.f,  2.f,  3.f,    1.f,  2.f,  3.f,    4.f,  5.f,  6.f,    4.f,  5.f,  6.f,    7.f,  8.f,  9.f,    7.f,  8.f,  9.f,
         1.f,  2.f,  3.f,    1.f,  2.f,  3.f,    4.f,  5.f,  6.f,    4.f,  5.f,  6.f,    7.f,  8.f,  9.f,    7.f,  8.f,  9.f,
        10.f, 11.f, 12.f,   10.f, 11.f, 12.f,   13.f, 14.f, 15.f,   13.f, 14.f, 15.f,   16.f, 17.f, 18.f,   16.f, 17.f, 18.f,
        10.f, 11.f, 12.f,   10.f, 11.f, 12.f,   13.f, 14.f, 15.f,   13.f, 14.f, 15.f,   16.f, 17.f, 18.f,   16.f, 17.f, 18.f,
        19.f, 20.f, 21.f,   19.f, 20.f, 21.f,   22.f, 23.f, 24.f,   22.f, 23.f, 24.f,   25.f, 26.f, 27.f,   25.f, 26.f, 27.f,
        19.f, 20.f, 21.f,   19.f, 20.f, 21.f,   22.f, 23.f, 24.f,   22.f, 23.f, 24.f,   25.f, 26.f, 27.f,   25.f, 26.f, 27.f
    });

    // Requested output size, handed to the op as its second input array.
    auto targetSize = NDArrayFactory::create<int>({6, 6});

    sd::ops::resize_area areaResize;
    auto* outcome = areaResize.evaluate({&source, &targetSize}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto* resized = outcome->at(0);
    ASSERT_TRUE(reference.isSameShape(resized));
    ASSERT_TRUE(reference.equalsTo(resized));

    // The result set returned by evaluate() is released by hand.
    delete outcome;
}
|
|
|
|
|
|
|
|
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test6) {
    // Two identical single-channel 3x3 images holding 1..9.
    auto source = NDArrayFactory::create<int>('c', {2, 3, 3, 1}, {
        1, 2, 3,
        4, 5, 6,
        7, 8, 9,

        1, 2, 3,
        4, 5, 6,
        7, 8, 9
    });

    // Reference 6x6 output, one table per image.
    auto reference = NDArrayFactory::create<float>('c', {2, 6, 6, 1}, {
        // first image
        1.f,  1.f,  1.5f, 2.f,  2.f,  3.f,
        1.f,  1.f,  1.5f, 2.f,  2.f,  3.f,
        2.5f, 2.5f, 3.f,  3.5f, 3.5f, 4.5f,
        4.f,  4.f,  4.5f, 5.f,  5.f,  6.f,
        4.f,  4.f,  4.5f, 5.f,  5.f,  6.f,
        7.f,  7.f,  7.5f, 8.f,  8.f,  9.f,

        // second image
        1.f,  1.f,  1.5f, 2.f,  2.f,  3.f,
        1.f,  1.f,  1.5f, 2.f,  2.f,  3.f,
        2.5f, 2.5f, 3.f,  3.5f, 3.5f, 4.5f,
        4.f,  4.f,  4.5f, 5.f,  5.f,  6.f,
        4.f,  4.f,  4.5f, 5.f,  5.f,  6.f,
        7.f,  7.f,  7.5f, 8.f,  8.f,  9.f
    });

    // Requested output size, handed to the op as its second input array.
    auto targetSize = NDArrayFactory::create<int>({6, 6});

    sd::ops::resize_area areaResize;
    // The trailing {true} is the op's single boolean argument
    // (presumably align-corners; confirm against the resize_area declaration).
    auto* outcome = areaResize.evaluate({&source, &targetSize}, {}, {}, {true});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto* resized = outcome->at(0);
    ASSERT_TRUE(reference.isSameShape(resized));
    ASSERT_TRUE(reference.equalsTo(resized));

    // The result set returned by evaluate() is released by hand.
    delete outcome;
}
|
|
|
|
|
|
|
|
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test7) {
    // Two identical single-channel 3x3 images holding 1..9.
    auto source = NDArrayFactory::create<int>('c', {2, 3, 3, 1}, {
        1, 2, 3,
        4, 5, 6,
        7, 8, 9,

        1, 2, 3,
        4, 5, 6,
        7, 8, 9
    });

    // Reference 6x6 output, one table per image.
    auto reference = NDArrayFactory::create<float>('c', {2, 6, 6, 1}, {
        // first image
        1.f,  1.f,  1.5f, 2.f,  2.f,  3.f,
        1.f,  1.f,  1.5f, 2.f,  2.f,  3.f,
        2.5f, 2.5f, 3.f,  3.5f, 3.5f, 4.5f,
        4.f,  4.f,  4.5f, 5.f,  5.f,  6.f,
        4.f,  4.f,  4.5f, 5.f,  5.f,  6.f,
        7.f,  7.f,  7.5f, 8.f,  8.f,  9.f,

        // second image
        1.f,  1.f,  1.5f, 2.f,  2.f,  3.f,
        1.f,  1.f,  1.5f, 2.f,  2.f,  3.f,
        2.5f, 2.5f, 3.f,  3.5f, 3.5f, 4.5f,
        4.f,  4.f,  4.5f, 5.f,  5.f,  6.f,
        4.f,  4.f,  4.5f, 5.f,  5.f,  6.f,
        7.f,  7.f,  7.5f, 8.f,  8.f,  9.f
    });

    sd::ops::resize_area areaResize;
    // Here the 6x6 target size travels through the integer arguments rather than
    // a second input array; {true} is the op's single boolean argument
    // (presumably align-corners; confirm against the resize_area declaration).
    auto* outcome = areaResize.evaluate({&source}, {}, {6, 6}, {true});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto* resized = outcome->at(0);
    ASSERT_TRUE(reference.isSameShape(resized));
    ASSERT_TRUE(reference.equalsTo(resized));

    // The result set returned by evaluate() is released by hand.
    delete outcome;
}
|
|
|
|
|
|
|
|
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test8) {
    // One single-channel 3x3 image holding 1..9.
    auto source = NDArrayFactory::create<int>('c', {1, 3, 3, 1}, {
        1, 2, 3,
        4, 5, 6,
        7, 8, 9
    });

    // Reference 6x6 output for that image.
    auto reference = NDArrayFactory::create<float>('c', {1, 6, 6, 1}, {
        1.f,  1.f,  1.5f, 2.f,  2.f,  3.f,
        1.f,  1.f,  1.5f, 2.f,  2.f,  3.f,
        2.5f, 2.5f, 3.f,  3.5f, 3.5f, 4.5f,
        4.f,  4.f,  4.5f, 5.f,  5.f,  6.f,
        4.f,  4.f,  4.5f, 5.f,  5.f,  6.f,
        7.f,  7.f,  7.5f, 8.f,  8.f,  9.f
    });

    sd::ops::resize_area areaResize;
    // The 6x6 target size travels through the integer arguments rather than a
    // second input array; {true} is the op's single boolean argument
    // (presumably align-corners; confirm against the resize_area declaration).
    auto* outcome = areaResize.evaluate({&source}, {}, {6, 6}, {true});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto* resized = outcome->at(0);
    ASSERT_TRUE(reference.isSameShape(resized));
    ASSERT_TRUE(reference.equalsTo(resized));

    // The result set returned by evaluate() is released by hand.
    delete outcome;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
// Runs resize_area on a 1x2x3x4 int array with a {10, 10} size array as second input and
// compares the first output against a float reference declared with shape {1, 10, 10, 4}.
// NOTE(review): the `expected` initializer below appears cut short in this copy (its final
// literal has no `f` suffix and the list stops there) - verify against the upstream file.
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test9) {
|
|
|
|
|
|
|
|
NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4}, {
|
|
|
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24
|
|
|
|
});
|
|
|
|
|
|
|
|
NDArray expected = NDArrayFactory::create<float>('c', {1, 10, 10, 4}, {
|
|
|
|
1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 
7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333336f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999998f, 9.999997f, 10.999997f, 11.999997f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 15.666671f, 16.666672f, 17.666672f, 18.666672f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 18.333344f, 19.333344f, 20.333345f, 21.333344f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000002f, 22.000000f, 23.000002f, 24.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 
20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000
|
|
|
|
|
|
|
|
});
|
|
|
|
//input.linspace(1);
|
|
|
|
// Target size handed to the op as its second input array.
auto size = NDArrayFactory::create<int>({10, 10});
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::resize_area op;
|
2020-01-30 08:07:24 +01:00
|
|
|
auto results = op.evaluate({&input, &size}, {}, {});
|
2020-01-24 18:55:25 +01:00
|
|
|
|
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
|
|
|
|
|
NDArray* result = results->at(0);
|
|
|
|
|
|
|
|
|
// Debug output, disabled:
// result->printBuffer("Area Resized to 10x10");
|
|
|
|
// expected.printBuffer("Area Expect for 10x10");
|
|
|
|
ASSERT_TRUE(expected.isSameShape(result));
|
|
|
|
ASSERT_TRUE(expected.equalsTo(result));
|
|
|
|
// The result set returned by evaluate() is released by hand; the ASSERT_* macros above
// return early on failure, in which case this line is not reached.
delete results;
|
|
|
|
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
// Runs resize_area on a 1x2x3x4 int array with {10, 10} passed as integer arguments (no size
// array) and compares the first output against a float reference declared as {1, 10, 10, 4}.
// NOTE(review): the `expected` initializer below appears cut short in this copy (its final
// literal has no `f` suffix and the list stops there) - verify against the upstream file.
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test10) {
|
|
|
|
|
|
|
|
NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4}, {
|
|
|
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24
|
|
|
|
});
|
|
|
|
|
|
|
|
NDArray expected = NDArrayFactory::create<float>('c', {1, 10, 10, 4}, {
|
|
|
|
1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 
7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333336f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999998f, 9.999997f, 10.999997f, 11.999997f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 15.666671f, 16.666672f, 17.666672f, 18.666672f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 18.333344f, 19.333344f, 20.333345f, 21.333344f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000002f, 22.000000f, 23.000002f, 24.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 
20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000
|
|
|
|
|
|
|
|
});
|
|
|
|
//input.linspace(1);
|
|
|
|
//auto size = NDArrayFactory::create<int>({10, 10});
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::resize_area op;
|
2020-01-30 08:07:24 +01:00
|
|
|
// The {10, 10} target size is passed through the integer arguments here, not as an input array.
auto results = op.evaluate({&input}, {}, {10, 10});
|
2020-01-24 18:55:25 +01:00
|
|
|
|
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
|
|
|
|
|
NDArray* result = results->at(0);
|
|
|
|
|
|
|
|
|
// Debug output, disabled:
// result->printBuffer("Area Resized to 10x10");
|
|
|
|
// expected.printBuffer("Area Expect for 10x10");
|
2020-01-22 08:46:33 +01:00
|
|
|
ASSERT_TRUE(expected.isSameShape(result));
|
|
|
|
ASSERT_TRUE(expected.equalsTo(result));
|
2020-01-24 18:55:25 +01:00
|
|
|
// The result set returned by evaluate() is released by hand; the ASSERT_* macros above
// return early on failure, in which case this line is not reached.
delete results;
|
|
|
|
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test11) {
    // One 2x3 image with 4 channels, filled with 1..24.
    NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4}, {
        1,  2,  3,  4,    5,  6,  7,  8,    9, 10, 11, 12,
       13, 14, 15, 16,   17, 18, 19, 20,   21, 22, 23, 24
    });

    sd::ops::resize_area op;
    // The target size {6, 9} (height, width) is passed through the integer arguments.
    auto results = op.evaluate({&input}, {}, {6, 9});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    NDArray* result = results->at(0);

    // Element values are not checked: this test has no reference buffer for a 6x9 output.
    // TODO: add verified reference values and compare with equalsTo().
    // The output geometry is still pinned so the test verifies more than the status code.
    // EXPECT_* (non-fatal) is used so that `results` is always released below.
    const bool isRank4 = result->rankOf() == 4;
    EXPECT_TRUE(isRank4);
    if (isRank4) {
        EXPECT_EQ(1, result->sizeAt(0));   // batch
        EXPECT_EQ(6, result->sizeAt(1));   // requested height
        EXPECT_EQ(9, result->sizeAt(2));   // requested width
        EXPECT_EQ(4, result->sizeAt(3));   // channels carried over from the input
    }

    delete results;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test12) {
    // One 2x3 image with 4 channels, filled with 1..24.
    NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4}, {
        1,  2,  3,  4,    5,  6,  7,  8,    9, 10, 11, 12,
       13, 14, 15, 16,   17, 18, 19, 20,   21, 22, 23, 24
    });

    sd::ops::resize_area op;
    // The target size {10, 15} (height, width) is passed through the integer arguments.
    auto results = op.evaluate({&input}, {}, {10, 15});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    NDArray* result = results->at(0);

    // Element values are not checked: this test has no reference buffer for a 10x15 output.
    // TODO: add verified reference values and compare with equalsTo().
    // The output geometry is still pinned so the test verifies more than the status code.
    // EXPECT_* (non-fatal) is used so that `results` is always released below.
    const bool isRank4 = result->rankOf() == 4;
    EXPECT_TRUE(isRank4);
    if (isRank4) {
        EXPECT_EQ(1, result->sizeAt(0));    // batch
        EXPECT_EQ(10, result->sizeAt(1));   // requested height
        EXPECT_EQ(15, result->sizeAt(2));   // requested width
        EXPECT_EQ(4, result->sizeAt(3));    // channels carried over from the input
    }

    delete results;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test13) {
    // One 2x3 image with 4 channels, filled with 1..24.
    NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4}, {
        1,  2,  3,  4,    5,  6,  7,  8,    9, 10, 11, 12,
       13, 14, 15, 16,   17, 18, 19, 20,   21, 22, 23, 24
    });

    sd::ops::resize_area op;
    // The target size {9, 9} (height, width) is passed through the integer arguments.
    auto results = op.evaluate({&input}, {}, {9, 9});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    NDArray* result = results->at(0);

    // Element values are not checked: this test has no reference buffer for a 9x9 output.
    // TODO: add verified reference values and compare with equalsTo().
    // The output geometry is still pinned so the test verifies more than the status code.
    // EXPECT_* (non-fatal) is used so that `results` is always released below.
    const bool isRank4 = result->rankOf() == 4;
    EXPECT_TRUE(isRank4);
    if (isRank4) {
        EXPECT_EQ(1, result->sizeAt(0));   // batch
        EXPECT_EQ(9, result->sizeAt(1));   // requested height
        EXPECT_EQ(9, result->sizeAt(2));   // requested width
        EXPECT_EQ(4, result->sizeAt(3));   // channels carried over from the input
    }

    delete results;
}
|
|
|
|
|
2020-02-12 17:02:42 +01:00
|
|
|
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test14) {
    // resize_area on a 1x5x5x1 integer image; the 8x7 target size is passed
    // as a second input array.
    auto image = NDArrayFactory::create<int>('c', {1, 5, 5, 1}, {
         1,  2,  3,  4,  5,
         6,  7,  8,  9, 10,
        11, 12, 13, 14, 15,
        16, 17, 18, 19, 20,
        21, 22, 23, 24, 25
    });

    // Reference output, one output row (7 values) per line.
    auto expectedOutput = NDArrayFactory::create<float>('c', {1, 8, 7, 1}, {
        1.f,        1.6f,       2.1999993f, 2.9999995f, 3.8f,       4.399997f,  5.f,
        2.9999995f, 3.5999997f, 4.199999f,  4.9999995f, 5.8f,       6.3999963f, 7.f,
        5.999999f,  6.6f,       7.1999984f, 7.9999995f, 8.8f,       9.399994f,  10.f,
        10.f,       10.6f,      11.199998f, 12.f,       12.8f,      13.399992f, 14.f,
        12.f,       12.599999f, 13.199998f, 13.999998f, 14.800002f, 15.399991f, 16.f,
        15.999999f, 16.599998f, 17.199995f, 18.f,       18.800003f, 19.399986f, 20.000002f,
        19.f,       19.599998f, 20.199997f, 20.999998f, 21.800003f, 22.399984f, 23.000002f,
        20.999998f, 21.599998f, 22.199995f, 22.999998f, 23.800001f, 24.399984f, 25.f
    });

    auto targetSize = NDArrayFactory::create<int>({8, 7});

    sd::ops::resize_area op;
    // NOTE(review): `{false}` sits in the third (integer-argument) position here,
    // whereas ImageResizeArea_Test15 passes it as a boolean argument - confirm intent.
    auto outcome = op.evaluate({&image, &targetSize}, {}, {false});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto* resized = outcome->at(0);
    // Debug helpers:
    //   resized->printBuffer("Area Resized to 8x7");
    //   expectedOutput.printBuffer("Area Expect for 8x7");
    ASSERT_TRUE(expectedOutput.isSameShape(resized));
    ASSERT_TRUE(expectedOutput.equalsTo(resized));

    delete outcome;
}
|
|
|
|
|
|
|
|
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test15) {
    // Same image and reference data as ImageResizeArea_Test14, but the 8x7
    // target size is passed through the op's integer arguments instead of a
    // second input array.
    auto image = NDArrayFactory::create<int>('c', {1, 5, 5, 1}, {
         1,  2,  3,  4,  5,
         6,  7,  8,  9, 10,
        11, 12, 13, 14, 15,
        16, 17, 18, 19, 20,
        21, 22, 23, 24, 25
    });

    // Reference output, one output row (7 values) per line.
    auto expectedOutput = NDArrayFactory::create<float>('c', {1, 8, 7, 1}, {
        1.f,        1.6f,       2.1999993f, 2.9999995f, 3.8f,       4.399997f,  5.f,
        2.9999995f, 3.5999997f, 4.199999f,  4.9999995f, 5.8f,       6.3999963f, 7.f,
        5.999999f,  6.6f,       7.1999984f, 7.9999995f, 8.8f,       9.399994f,  10.f,
        10.f,       10.6f,      11.199998f, 12.f,       12.8f,      13.399992f, 14.f,
        12.f,       12.599999f, 13.199998f, 13.999998f, 14.800002f, 15.399991f, 16.f,
        15.999999f, 16.599998f, 17.199995f, 18.f,       18.800003f, 19.399986f, 20.000002f,
        19.f,       19.599998f, 20.199997f, 20.999998f, 21.800003f, 22.399984f, 23.000002f,
        20.999998f, 21.599998f, 22.199995f, 22.999998f, 23.800001f, 24.399984f, 25.f
    });

    sd::ops::resize_area op;
    auto outcome = op.evaluate({&image}, {}, {8, 7}, {false});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto* resized = outcome->at(0);
    // Debug helpers:
    //   resized->printBuffer("Area Resized to 8x7");
    //   expectedOutput.printBuffer("Area Expect for 8x7");
    ASSERT_TRUE(expectedOutput.isSameShape(resized));
    ASSERT_TRUE(expectedOutput.equalsTo(resized));

    delete outcome;
}
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {
    // Exercises default construction, copy assignment and copy construction
    // of SummaryStatsData<double>.
    using Stats = functions::summarystats::SummaryStatsData<double>;

    Stats var1;                 // left in its default-constructed state
    Stats var2;
    var2.n = var2.mean = var2.M2 = var2.M3 = var2.M4 = var2.bias = 5;

    // Automatic storage instead of new[]/delete[]: a failing ASSERT_* returns
    // from the test body immediately, which would skip a trailing delete[]
    // and leak the array.
    Stats arr[2];
    arr[0] = var1;
    arr[1] = var2;
    arr[0] = arr[1];            // element-to-element assignment overwrites the default copy

    Stats var3(var1);           // copy of the untouched default object

    ASSERT_TRUE(arr[0].n == arr[0].mean && arr[0].M2 == arr[0].M3 && arr[0].n == 5);
    ASSERT_TRUE(arr[1].n == arr[1].mean && arr[1].M2 == arr[1].M3 && arr[1].n == 5);
    ASSERT_TRUE(var3.n == var3.mean && var3.M2 == var3.M3 && var3.n == 0);
}
|
|
|
|
|
2020-02-04 06:59:11 +01:00
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, Solve_Test_1) {
    // Single 3x3 system with one right-hand-side column.
    auto coeffs = NDArrayFactory::create<float>('c', {3, 3}, {
         2.f, -1.f, -2.f,
        -4.f,  6.f,  3.f,
        -4.f, -2.f,  8.f
    });
    auto rhs      = NDArrayFactory::create<float>('c', {3, 1}, {2.f, 4.f, 3.f});
    auto expected = NDArrayFactory::create<float>('c', {3, 1}, {7.625f, 3.25f, 5.f});

    sd::ops::solve op;
    auto outcome = op.evaluate({&coeffs, &rhs});
    ASSERT_EQ(outcome->status(), ND4J_STATUS_OK);

    auto solution = outcome->at(0);
    // solution->printIndexedBuffer("Solve of 3x3");
    ASSERT_TRUE(expected.equalsTo(solution));

    delete outcome;
}
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, Solve_Test_2) {
    // 4x4 upper-triangular coefficient matrix with a single right-hand-side column.
    auto coeffs = NDArrayFactory::create<float>('c', {4, 4}, {
        1.f, 1.f, 1.f, 1.f,
        0.f, 1.f, 1.f, 0.f,
        0.f, 0.f, 2.f, 1.f,
        0.f, 0.f, 0.f, 3.f
    });
    auto rhs = NDArrayFactory::create<float>('c', {4, 1}, {2.f, 4.f, 2.f, 4.f});
    auto expected = NDArrayFactory::create<float>('c', {4, 1}, {
        -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f
    });

    sd::ops::solve op;
    auto outcome = op.evaluate({&coeffs, &rhs});
    ASSERT_EQ(outcome->status(), ND4J_STATUS_OK);

    auto solution = outcome->at(0);
    // solution->printIndexedBuffer("Solve 4x4");
    ASSERT_TRUE(expected.equalsTo(solution));

    delete outcome;
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, Solve_Test_3) {
    // Batched variant: two 4x4 systems (one upper-, one lower-triangular),
    // each with its own 4x1 right-hand side.
    auto coeffs = NDArrayFactory::create<float>('c', {2, 4, 4}, {
        // first matrix
        1.f, 1.f, 1.f, 1.f,
        0.f, 1.f, 1.f, 0.f,
        0.f, 0.f, 2.f, 1.f,
        0.f, 0.f, 0.f, 3.f,
        // second matrix
        3.f, 0.f, 0.f, 0.f,
        2.f, 1.f, 0.f, 0.f,
        1.f, 0.f, 1.f, 0.f,
        1.f, 1.f, 1.f, 1.f
    });
    auto rhs = NDArrayFactory::create<float>('c', {2, 4, 1}, {
        2.f, 4.f, 2.f, 4.f,
        4.f, 2.f, 4.f, 2.f
    });
    auto expected = NDArrayFactory::create<float>('c', {2, 4, 1}, {
        -3.3333333f,  3.6666666f, 0.333333f,   1.3333333f,
         1.333333f,  -0.6666667f, 2.6666667f, -1.3333333f
    });

    sd::ops::solve op;
    auto outcome = op.evaluate({&coeffs, &rhs});
    ASSERT_EQ(outcome->status(), ND4J_STATUS_OK);

    auto solution = outcome->at(0);
    // solution->printIndexedBuffer("Solve 4x4");
    ASSERT_TRUE(expected.equalsTo(solution));

    delete outcome;
}
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, Solve_Test_4) {
    // Two dense 2x2 systems, each with a 2x2 right-hand side.
    auto coeffs = NDArrayFactory::create<float>('c', {2, 2, 2}, {
        0.7788f, 0.8012f, 0.7244f, 0.2309f,
        0.7271f, 0.1804f, 0.5056f, 0.8925f
    });
    auto rhs = NDArrayFactory::create<float>('c', {2, 2, 2}, {
        0.7717f, 0.9281f, 0.9846f, 0.4838f,
        0.6433f, 0.6041f, 0.6501f, 0.7612f
    });
    // Alternative reference values that the original keeps commented out:
    //   1.524494767f, 0.432706356f, -0.518630624f, 0.737760842f,
    //   0.819143713f, 0.720401764f,  0.264349997f, 0.444699198f
    auto expected = NDArrayFactory::create<float>('c', {2, 2, 2}, {
        1.5245394f,  0.4326952f,  -0.51873577f, 0.7377896f,
        0.81915987f, 0.72049433f,  0.2643504f,  0.44472617f
    });

    sd::ops::solve op;
    auto outcome = op.evaluate({&coeffs, &rhs});
    ASSERT_EQ(outcome->status(), ND4J_STATUS_OK);

    auto solution = outcome->at(0);
    // Debug helpers:
    //   solution->printBuffer("4 Solve 4x4");
    //   expected.printBuffer("4 Expec 4x4");
    ASSERT_TRUE(expected.equalsTo(solution));

    delete outcome;
}
|
2020-02-06 19:06:50 +01:00
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, Solve_Test_4_1) {
    // Same inputs as Solve_Test_4, but the op receives one boolean argument
    // set to true (presumably the adjoint flag - confirm against the op declaration).
    auto coeffs = NDArrayFactory::create<float>('c', {2, 2, 2}, {
        0.7788f, 0.8012f, 0.7244f, 0.2309f,
        0.7271f, 0.1804f, 0.5056f, 0.8925f
    });
    auto rhs = NDArrayFactory::create<float>('c', {2, 2, 2}, {
        0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f
    });
    auto expected = NDArrayFactory::create<float>('c', {2, 2, 2}, {
        1.3357621f, 0.3399364f, -0.37077796f, 0.91573375f,
        0.4400987f, 0.2766527f,  0.6394467f,  0.79696566f
    });

    sd::ops::solve op;
    auto outcome = op.evaluate({&coeffs, &rhs}, {true});
    ASSERT_EQ(outcome->status(), ND4J_STATUS_OK);

    auto solution = outcome->at(0);
    // Debug helpers:
    //   solution->printBuffer("4 Solve 4x4");
    //   expected.printBuffer("4 Expec 4x4");
    ASSERT_TRUE(expected.equalsTo(solution));

    delete outcome;
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, Solve_Test_4_2) {
    // triangular_solve on a dense 3x3 matrix with boolean arguments {true, false}
    // (presumably {lower, adjoint} - confirm against the op declaration).
    auto coeffs = NDArrayFactory::create<float>('c', {3, 3}, {
        0.7788f, 0.8012f, 0.7244f,
        0.2309f, 0.7271f, 0.1804f,
        0.5056f, 0.8925f, 0.5461f
    });
    auto rhs = NDArrayFactory::create<float>('c', {3, 3}, {
        0.7717f, 0.9281f, 0.9846f,
        0.4838f, 0.6433f, 0.6041f,
        0.6501f, 0.7612f, 0.7605f
    });
    auto expected = NDArrayFactory::create<float>('c', {3, 3}, {
         0.99088347f,  1.1917052f,   1.2642528f,
         0.35071516f,  0.50630623f,  0.42935497f,
        -0.30013534f, -0.53690606f, -0.47959247f
    });

    sd::ops::triangular_solve op;
    auto outcome = op.evaluate({&coeffs, &rhs}, {true, false});
    ASSERT_EQ(outcome->status(), ND4J_STATUS_OK);

    auto solution = outcome->at(0);
    // Debug helpers:
    //   solution->printBuffer("4_2 Triangular_Solve 3x3");
    //   expected.printBuffer("4_2 Triangular_Expec 3x3");
    ASSERT_TRUE(expected.equalsTo(solution));

    delete outcome;
}
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, Solve_Test_4_3) {
    // triangular_solve on the same 3x3 inputs as Solve_Test_4_2, this time with
    // boolean arguments {true, true}.
    auto coeffs = NDArrayFactory::create<float>('c', {3, 3}, {
        0.7788f, 0.8012f, 0.7244f,
        0.2309f, 0.7271f, 0.1804f,
        0.5056f, 0.8925f, 0.5461f
    });
    auto rhs = NDArrayFactory::create<float>('c', {3, 3}, {
        0.7717f, 0.9281f, 0.9846f,
        0.4838f, 0.6433f, 0.6041f,
        0.6501f, 0.7612f, 0.7605f
    });
    auto expected = NDArrayFactory::create<float>('c', {3, 3}, {
         0.45400196f,  0.53174824f,  0.62064564f,
        -0.79585856f, -0.82621557f, -0.87855506f,
         1.1904413f,   1.3938838f,   1.3926021f
    });

    sd::ops::triangular_solve op;
    auto outcome = op.evaluate({&coeffs, &rhs}, {true, true});
    ASSERT_EQ(outcome->status(), ND4J_STATUS_OK);

    auto solution = outcome->at(0);
    // Debug helpers:
    //   solution->printBuffer("4_3 Triangular_Solve 3x3");
    //   expected.printBuffer("4_3 Triangular_Expec 3x3");
    ASSERT_TRUE(expected.equalsTo(solution));

    delete outcome;
}
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, Solve_Test_4_4) {
    // solve on a dense 3x3 system with the single boolean argument set to false.
    auto coeffs = NDArrayFactory::create<float>('c', {3, 3}, {
        0.7788f, 0.8012f, 0.7244f,
        0.2309f, 0.7271f, 0.1804f,
        0.5056f, 0.8925f, 0.5461f
    });
    auto rhs = NDArrayFactory::create<float>('c', {3, 3}, {
        0.7717f, 0.9281f, 0.9846f,
        0.4838f, 0.6433f, 0.6041f,
        0.6501f, 0.7612f, 0.7605f
    });
    auto expected = NDArrayFactory::create<float>('c', {3, 3}, {
         0.8959121f,   1.6109066f,   1.7501404f,
         0.49000582f,  0.66842675f,  0.5577021f,
        -0.4398522f,  -1.1899745f,  -1.1392052f
    });

    sd::ops::solve op;
    auto outcome = op.evaluate({&coeffs, &rhs}, {false});
    ASSERT_EQ(outcome->status(), ND4J_STATUS_OK);

    auto solution = outcome->at(0);
    // Debug helpers:
    //   solution->printBuffer("4_4 Solve 3x3");
    //   expected.printBuffer("4_4 Expec 3x3");
    ASSERT_TRUE(expected.equalsTo(solution));

    delete outcome;
}
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, Solve_Test_4_5) {
    // solve on the same dense 3x3 system as Solve_Test_4_4, with two boolean
    // arguments {true, true}.
    auto coeffs = NDArrayFactory::create<float>('c', {3, 3}, {
        0.7788f, 0.8012f, 0.7244f,
        0.2309f, 0.7271f, 0.1804f,
        0.5056f, 0.8925f, 0.5461f
    });
    auto rhs = NDArrayFactory::create<float>('c', {3, 3}, {
        0.7717f, 0.9281f, 0.9846f,
        0.4838f, 0.6433f, 0.6041f,
        0.6501f, 0.7612f, 0.7605f
    });
    auto expected = NDArrayFactory::create<float>('c', {3, 3}, {
         1.5504692f,   1.8953944f,   2.2765768f,
         0.03399149f,  0.2883001f,   0.5377323f,
        -0.8774802f,  -1.2155888f,  -1.8049058f
    });

    sd::ops::solve op;
    auto outcome = op.evaluate({&coeffs, &rhs}, {true, true});
    ASSERT_EQ(outcome->status(), ND4J_STATUS_OK);

    auto solution = outcome->at(0);
    // Debug helpers:
    //   solution->printBuffer("4_5 Solve 3x3");
    //   expected.printBuffer("4_5 Expec 3x3");
    ASSERT_TRUE(expected.equalsTo(solution));

    delete outcome;
}
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, Solve_Test_4_6) {
    // triangular_solve with boolean arguments {false, true}; the expected
    // values are the same as in Solve_Test_4_7.
    auto coeffs = NDArrayFactory::create<float>('c', {3, 3}, {
        0.7788f, 0.8012f, 0.7244f,
        0.2309f, 0.7271f, 0.1804f,
        0.5056f, 0.8925f, 0.5461f
    });
    auto rhs = NDArrayFactory::create<float>('c', {3, 3}, {
        0.7717f, 0.9281f, 0.9846f,
        0.4838f, 0.6433f, 0.6041f,
        0.6501f, 0.7612f, 0.7605f
    });
    auto expected = NDArrayFactory::create<float>('c', {3, 3}, {
         0.99088347f,  1.1917052f,   1.2642528f,
        -0.426483f,   -0.42840624f, -0.5622601f,
         0.01692283f, -0.04538865f, -0.09868701f
    });

    sd::ops::triangular_solve op;
    auto outcome = op.evaluate({&coeffs, &rhs}, {false, true});
    ASSERT_EQ(outcome->status(), ND4J_STATUS_OK);

    auto solution = outcome->at(0);
    // Debug helpers:
    //   solution->printBuffer("4_6 Solve 3x3");
    //   expected.printBuffer("4_6 Expec 3x3");
    ASSERT_TRUE(expected.equalsTo(solution));

    delete outcome;
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, Solve_Test_4_7) {
    // triangular_solve with boolean arguments {true, false}. The coefficient
    // matrix is the transpose of the one used in Solve_Test_4_6, and the
    // expected values are the same as in that test.
    auto coeffs = NDArrayFactory::create<float>('c', {3, 3}, {
        0.7788f, 0.2309f, 0.5056f,
        0.8012f, 0.7271f, 0.8925f,
        0.7244f, 0.1804f, 0.5461f
    });
    auto rhs = NDArrayFactory::create<float>('c', {3, 3}, {
        0.7717f, 0.9281f, 0.9846f,
        0.4838f, 0.6433f, 0.6041f,
        0.6501f, 0.7612f, 0.7605f
    });
    auto expected = NDArrayFactory::create<float>('c', {3, 3}, {
         0.99088347f,  1.1917052f,   1.2642528f,
        -0.426483f,   -0.42840624f, -0.5622601f,
         0.01692283f, -0.04538865f, -0.09868701f
    });

    sd::ops::triangular_solve op;
    auto outcome = op.evaluate({&coeffs, &rhs}, {true, false});
    ASSERT_EQ(outcome->status(), ND4J_STATUS_OK);

    auto solution = outcome->at(0);
    // Debug helpers:
    //   solution->printBuffer("4_7 Solve 3x3");
    //   expected.printBuffer("4_7 Expec 3x3");
    ASSERT_TRUE(expected.equalsTo(solution));

    delete outcome;
}
|
2020-02-04 06:59:11 +01:00
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, Solve_Test_5) {
    // solve on a dense 3x3 system with the single boolean argument set to true;
    // the reference values match those of Solve_Test_4_5.
    auto coeffs = NDArrayFactory::create<float>('c', {3, 3}, {
        0.7788f, 0.8012f, 0.7244f,
        0.2309f, 0.7271f, 0.1804f,
        0.5056f, 0.8925f, 0.5461f
    });
    auto rhs = NDArrayFactory::create<float>('c', {3, 3}, {
        0.7717f, 0.9281f, 0.9846f,
        0.4838f, 0.6433f, 0.6041f,
        0.6501f, 0.7612f, 0.7605f
    });
    auto expected = NDArrayFactory::create<float>('c', {3, 3}, {
         1.5504692f,   1.8953944f,   2.2765768f,
         0.03399149f,  0.2883001f,   0.5377323f,
        -0.8774802f,  -1.2155888f,  -1.8049058f
    });

    sd::ops::solve op;
    auto outcome = op.evaluate({&coeffs, &rhs}, {true});
    ASSERT_EQ(outcome->status(), ND4J_STATUS_OK);

    auto solution = outcome->at(0);
    // Debug helpers:
    //   solution->printBuffer("5 Solve 3x3");
    //   expected.printBuffer("5 Expec 3x3");
    ASSERT_TRUE(expected.equalsTo(solution));

    delete outcome;
}
|
2020-02-28 09:37:26 +01:00
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, SolveLS_Test_1) {
    // lstsq on two 2x2 double-precision systems; result compared with a 1e-4 tolerance.
    auto coeffs = NDArrayFactory::create<double>('c', {2,2, 2}, {
        1.f, 2.f, 3.f, 4.f,
        5.f, 6.f, 7.f, 8.f
    });
    auto rhs = NDArrayFactory::create<double>('c', {2, 2, 1}, {
        3.f, 7.f, 11.f, 15.f
    });
    auto expected = NDArrayFactory::create<double>('c', {2, 2, 1}, {
        0.8311695f, 1.0909086f, 0.9205573f, 1.0630057f
    });

    sd::ops::lstsq op;
    // 0.5 is the op's only floating-point argument and `true` its only boolean
    // argument (presumably the L2 regularizer and the "fast" switch - confirm
    // against the op declaration).
    auto outcome = op.evaluate({&coeffs, &rhs}, {0.5}, {}, {true});
    ASSERT_EQ(outcome->status(), ND4J_STATUS_OK);

    auto solution = outcome->at(0);
    // Debug helpers:
    //   solution->printIndexedBuffer("LS Solve 2x2");
    //   expected.printIndexedBuffer("LS Expec 2x2");
    ASSERT_TRUE(expected.equalsTo(solution, 1.e-4));

    delete outcome;
}
|
|
|
|
|
|
|
|
TEST_F(DeclarableOpsTests11, SolveLS_Test_2) {
    // Single-precision counterpart of SolveLS_Test_1: same data, float arrays.
    auto coeffs = NDArrayFactory::create<float>('c', {2,2, 2}, {
        1.f, 2.f, 3.f, 4.f,
        5.f, 6.f, 7.f, 8.f
    });
    auto rhs = NDArrayFactory::create<float>('c', {2, 2, 1}, {
        3.f, 7.f, 11.f, 15.f
    });
    auto expected = NDArrayFactory::create<float>('c', {2, 2, 1}, {
        0.8311695f, 1.0909086f, 0.9205573f, 1.0630057f
    });

    sd::ops::lstsq op;
    auto outcome = op.evaluate({&coeffs, &rhs}, {0.5}, {}, {true});
    ASSERT_EQ(outcome->status(), ND4J_STATUS_OK);

    auto solution = outcome->at(0);
    // Debug helpers:
    //   solution->printIndexedBuffer("2LS Solve 2x2");
    //   expected.printIndexedBuffer("2LS Expec 2x2");
    ASSERT_TRUE(expected.equalsTo(solution, 1.e-4));

    delete outcome;
}
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, Cholesky_Test_2x2x2) {
    // Cholesky factorisation of a batch of two symmetric 2x2 matrices; the
    // reference values are the lower-triangular factors.
    auto a = NDArrayFactory::create<float>('c', {2,2, 2}, {
        10.f, 14.f,
        14.f, 20.f,

        74.f,  86.f,
        86.f, 100.f
    });

    auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {
        3.1622777f, 0.f, 4.427189f, 0.6324552f,
        8.602325f,  0.f, 9.997296f, 0.23252854f
    });

    sd::ops::cholesky op;

    auto res = op.evaluate({&a});
    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
    auto z = res->at(0);

    // Debug output is kept disabled, as in the neighbouring tests, so that a
    // normal run does not print buffers to the console.
    // z->printIndexedBuffer("L matrix is");
    // exp.printIndexedBuffer("L expected is");

    ASSERT_TRUE(exp.equalsTo(z));
    delete res;
}
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, Cholesky_Test_2x2x2_2) {
    // Round-trip check: factorise, multiply the factor by itself via
    // MmulHelper::matmul and compare the product with the input.
    auto input = NDArrayFactory::create<float>('c', {2,2, 2}, {
        10.5f, 14.f,
        14.f,  20.5f,

        74.5f,  86.f,
        86.f,  100.5f
    });

    // Holds reference factor values initially, but is then reused as the
    // output buffer of the matmul below.
    auto product = NDArrayFactory::create<float>('c', {2, 2, 2}, {
        3.2403703f, 0.f, 4.3204937f, 1.3540066f,
        8.631338f,  0.f, 9.963693f,  1.1067207f
    });

    sd::ops::cholesky op;
    auto outcome = op.evaluate({&input});
    ASSERT_EQ(outcome->status(), ND4J_STATUS_OK);

    auto factor = outcome->at(0);
    // Debug helpers:
    //   factor->printIndexedBuffer("L matrix is");
    //   product.printIndexedBuffer("L expected is");

    // The trailing false/true flags presumably mean "no transpose" / "transpose",
    // i.e. product = factor * factor^T - confirm against MmulHelper::matmul.
    MmulHelper::matmul(factor, factor, &product, false, true);
    ASSERT_TRUE(product.equalsTo(input));

    delete outcome;
}
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) {
    // Full-shape weights {2,3,4}, integer argument 0 (presumably the reduction
    // mode - confirm against the op declaration). Checks all three gradients.
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);

    NDArray expGradPred('c', {2,3,4}, {
         -0.96,  -1.92,  -2.88,  -3.84,  -4.8 ,  -5.76,  -6.72,  -7.68,  -8.64,  -9.6 , -10.56, -11.52,
        -12.48, -13.44, -14.4 , -15.36, -16.32, -17.28, -18.24, -19.2 , -20.16, -21.12, -22.08, -23.04});
    NDArray expGradWeights('c', {2,3,4}, {
          0.9216 ,   3.6864 ,   8.2944 ,  14.7456 ,  23.04   ,  33.1776 ,
         45.1584 ,  58.9824 ,  74.6496 ,  92.16   , 111.51361, 132.7104 ,
        155.75038, 180.63359, 207.35999, 235.9296 , 266.34238, 298.59842,
        332.6976 , 368.64001, 406.4256 , 446.05444, 487.5264 , 530.84161});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    sd::ops::mean_sqerr_loss_grad op;
    auto outcome = op.evaluate({&predictions, &weights, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto gradPred    = outcome->at(0);
    auto gradWeights = outcome->at(1);
    auto gradLabels  = outcome->at(2);

    ASSERT_TRUE(expGradPred.isSameShape(gradPred));
    ASSERT_TRUE(expGradPred.equalsTo(gradPred));
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));
    // The gradient w.r.t. labels is expected to be the negated prediction gradient.
    ASSERT_TRUE(expGradPred.isSameShape(-*gradLabels));
    ASSERT_TRUE(expGradPred.equalsTo(-*gradLabels));

    delete outcome;
}
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test2) {
    // Weights of shape {2,1,4} against {2,3,4} predictions/labels; only the
    // weights gradient is verified.
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {2,1,4}, sd::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);

    NDArray expGradWeights('c', {2,1,4}, {
         98.61121, 129.024  ,  164.9664 ,  206.4384 ,
        828.51837, 925.28644, 1027.58398, 1135.41113});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    sd::ops::mean_sqerr_loss_grad op;
    auto outcome = op.evaluate({&predictions, &weights, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    NDArray* gradWeights = outcome->at(1);
    ASSERT_TRUE(expGradWeights.isSameShape(gradWeights));
    ASSERT_TRUE(expGradWeights.equalsTo(gradWeights));

    delete outcome;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) {
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray weights(sd::DataType::DOUBLE);
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52,
|
|
|
|
-12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04});
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray dLdwExp('c', {}, std::vector<double>{4515.84});
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
predictions.linspace(0.04, 0.04);
|
|
|
|
labels.linspace(1);
|
2019-09-11 19:12:09 +02:00
|
|
|
weights.assign(0.5);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::mean_sqerr_loss_grad op;
|
2020-01-30 08:07:24 +01:00
|
|
|
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
|
2019-09-11 19:12:09 +02:00
|
|
|
auto *dLdp = results->at(0);
|
2019-06-06 14:21:15 +02:00
|
|
|
auto *dLdw = results->at(1);
|
|
|
|
auto *dLdl = results->at(2);
|
|
|
|
|
|
|
|
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
|
|
|
|
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
|
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
|
|
ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
|
|
|
|
ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
|
|
|
|
|
|
|
|
delete results;
|
|
|
|
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test4) {
    // Verifies only the weights gradient (output #1) of mean_sqerr_loss_grad
    // for 2x3x4 predictions/labels, {1,3,1} weights and integer argument 1.
    // NOTE(review): the integer argument presumably selects the reduction
    // mode - confirm against the op declaration.

    // Inputs.
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    labels.linspace(1);

    NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE);
    weights.assign(0.5);

    // Reference gradient with respect to the weights.
    NDArray expectedGradW('c', {1,3,1}, {807.32153, 1426.63684, 2281.88159});

    // Execute the op; inputs are passed as predictions, weights, labels.
    sd::ops::mean_sqerr_loss_grad lossGrad;
    auto outputs = lossGrad.evaluate({&predictions, &weights, &labels}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    // Compare shape first, then values.
    auto *gradW = outputs->at(1);
    ASSERT_TRUE(expectedGradW.isSameShape(gradW));
    ASSERT_TRUE(expectedGradW.equalsTo(gradW));

    delete outputs;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test5) {
    // Verifies all three outputs of mean_sqerr_loss_grad (gradients named
    // dLdp, dLdw, dLdl in output order) for 2x3x4 inputs with full-shape
    // weights and integer argument 2.
    // NOTE(review): the integer argument presumably selects the reduction
    // mode - confirm against the op declaration.

    // Inputs.
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    labels.linspace(1);

    NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE);
    weights.assign(0.5);

    // Reference gradients.
    NDArray expectedGradP('c', {2,3,4}, {-0.08, -0.16, -0.24, -0.32, -0.4, -0.48, -0.56, -0.64, -0.72, -0.8, -0.88, -0.96,
                                         -1.04, -1.12, -1.2, -1.28, -1.36, -1.44, -1.52, -1.6, -1.68, -1.76, -1.84, -1.92});
    NDArray expectedGradW('c', {2,3,4}, {-15.6032, -15.3728, -14.9888, -14.4512, -13.76, -12.9152, -11.9168, -10.7648, -9.4592, -8., -6.3872, -4.6208,
                                         -2.7008, -0.6272, 1.6, 3.9808, 6.5152, 9.2032, 12.0448, 15.04, 18.1888, 21.4912, 24.9472, 28.5568});

    // Execute the op; inputs are passed as predictions, weights, labels.
    sd::ops::mean_sqerr_loss_grad lossGrad;
    auto outputs = lossGrad.evaluate({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(expectedGradP.isSameShape(gradP));
    ASSERT_TRUE(expectedGradP.equalsTo(gradP));

    ASSERT_TRUE(expectedGradW.isSameShape(gradW));
    ASSERT_TRUE(expectedGradW.equalsTo(gradW));

    // The labels gradient is checked as the negation of the predictions reference.
    const auto negatedGradL = -(*gradL);
    ASSERT_TRUE(expectedGradP.isSameShape(negatedGradL));
    ASSERT_TRUE(expectedGradP.equalsTo(negatedGradL));

    delete outputs;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test6) {
    // Verifies only the weights gradient (output #1) of mean_sqerr_loss_grad
    // for 2x3x4 predictions/labels, {1,3,1} weights and integer argument 2.
    // NOTE(review): the integer argument presumably selects the reduction
    // mode - confirm against the op declaration.

    // Inputs.
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    labels.linspace(1);

    NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE);
    weights.assign(0.5);

    // Reference gradient with respect to the weights.
    NDArray expectedGradW('c', {1,3,1}, {-58.16319, -6.5536, 64.71682});

    // Execute the op; inputs are passed as predictions, weights, labels.
    sd::ops::mean_sqerr_loss_grad lossGrad;
    auto outputs = lossGrad.evaluate({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    // Compare shape first, then values.
    auto *gradW = outputs->at(1);
    ASSERT_TRUE(expectedGradW.isSameShape(gradW));
    ASSERT_TRUE(expectedGradW.equalsTo(gradW));

    delete outputs;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) {
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray weights(sd::DataType::DOUBLE);
|
2019-09-11 19:12:09 +02:00
|
|
|
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray dLdwExp('c', {}, std::vector<double>{0.});
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
predictions.linspace(0.04, 0.04);
|
|
|
|
labels.linspace(1);
|
2019-09-11 19:12:09 +02:00
|
|
|
weights.assign(0.5);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::mean_sqerr_loss_grad op;
|
2020-01-30 08:07:24 +01:00
|
|
|
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
|
|
|
|
auto *dLdw = results->at(1);
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
|
|
|
|
|
|
delete results;
|
|
|
|
}
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test8) {
    // Verifies all three outputs of mean_sqerr_loss_grad for 2x3x4 inputs
    // with full-shape weights whose first four elements are zeroed, using
    // integer argument 2.
    // NOTE(review): the integer argument presumably selects the reduction
    // mode - confirm against the op declaration.

    // Inputs.
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    labels.linspace(1);

    NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE);
    weights.assign(0.5);
    // Zero out elements 0..3 of the weights.
    for (int idx = 0; idx < 4; ++idx)
        weights.p(idx, 0.);

    // Reference gradients.
    NDArray expectedGradP('c', {2,3,4}, {0., 0., 0., 0., -0.48, -0.576, -0.672, -0.768, -0.864, -0.96, -1.056, -1.152,
                                         -1.248, -1.344, -1.44, -1.536, -1.632, -1.728, -1.824, -1.92, -2.016, -2.112, -2.208, -2.304});
    NDArray expectedGradW('c', {2,3,4}, {-22.3488, -22.07232, -21.61152, -20.9664, -20.13696, -19.1232, -17.92512, -16.54272, -14.976, -13.22496, -11.2896, -9.16992,
                                         -6.86592, -4.3776, -1.70496, 1.152, 4.19328, 7.41888, 10.8288, 14.42304, 18.2016, 22.16449, 26.31168, 30.6432});

    // Execute the op; inputs are passed as predictions, weights, labels.
    sd::ops::mean_sqerr_loss_grad lossGrad;
    auto outputs = lossGrad.evaluate({&predictions, &weights, &labels}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(expectedGradP.isSameShape(gradP));
    ASSERT_TRUE(expectedGradP.equalsTo(gradP));

    ASSERT_TRUE(expectedGradW.isSameShape(gradW));
    ASSERT_TRUE(expectedGradW.equalsTo(gradW));

    // The labels gradient is checked as the negation of the predictions reference.
    const auto negatedGradL = -(*gradL);
    ASSERT_TRUE(expectedGradP.isSameShape(negatedGradL));
    ASSERT_TRUE(expectedGradP.equalsTo(negatedGradL));

    delete outputs;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test9) {
    // Verifies all three outputs of mean_sqerr_loss_grad for 2x3x4 inputs
    // with full-shape weights and integer argument 3.
    // NOTE(review): the integer argument presumably selects the reduction
    // mode - confirm against the op declaration.

    // Inputs.
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    labels.linspace(1);

    NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE);
    weights.assign(0.5);

    // Reference gradients.
    NDArray expectedGradP('c', {2,3,4}, {-0.04, -0.08, -0.12, -0.16, -0.2, -0.24, -0.28, -0.32, -0.36, -0.4, -0.44, -0.48,
                                         -0.52, -0.56, -0.6, -0.64, -0.68, -0.72, -0.76, -0.8, -0.84, -0.88, -0.92, -0.96});
    NDArray expectedGradW('c', {2,3,4}, {0.0384, 0.1536, 0.3456, 0.6144, 0.96, 1.3824, 1.8816, 2.4576, 3.1104, 3.84, 4.6464, 5.5296,
                                         6.4896, 7.5264, 8.64, 9.8304, 11.0976, 12.4416, 13.8624, 15.36, 16.9344, 18.5856, 20.3136, 22.1184});

    // Execute the op; inputs are passed as predictions, weights, labels.
    sd::ops::mean_sqerr_loss_grad lossGrad;
    auto outputs = lossGrad.evaluate({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(expectedGradP.isSameShape(gradP));
    ASSERT_TRUE(expectedGradP.equalsTo(gradP));

    ASSERT_TRUE(expectedGradW.isSameShape(gradW));
    ASSERT_TRUE(expectedGradW.equalsTo(gradW));

    // The labels gradient is checked as the negation of the predictions reference.
    const auto negatedGradL = -(*gradL);
    ASSERT_TRUE(expectedGradP.isSameShape(negatedGradL));
    ASSERT_TRUE(expectedGradP.equalsTo(negatedGradL));

    delete outputs;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test10) {
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray weights('c', {1,1}, sd::DataType::DOUBLE);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray dLdwExp('c', {1,1}, std::vector<double>{188.16});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
predictions.linspace(0.04, 0.04);
|
|
|
|
labels.linspace(1);
|
|
|
|
weights.assign(0.5);
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::mean_sqerr_loss_grad op;
|
2020-01-30 08:07:24 +01:00
|
|
|
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
2019-09-11 19:12:09 +02:00
|
|
|
|
|
|
|
auto *dLdw = results->at(1);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
|
|
|
|
|
|
delete results;
|
|
|
|
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test11) {
    // Verifies only the weights gradient (output #1) of mean_sqerr_loss_grad
    // for 2x3x4 predictions/labels, {1,3,1} weights and integer argument 3.
    // NOTE(review): the integer argument presumably selects the reduction
    // mode - confirm against the op declaration.

    // Inputs.
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    labels.linspace(1);

    NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE);
    weights.assign(0.5);

    // Reference gradient with respect to the weights.
    NDArray expectedGradW('c', {1,3,1}, {33.6384, 59.4432, 95.07841});

    // Execute the op; inputs are passed as predictions, weights, labels.
    sd::ops::mean_sqerr_loss_grad lossGrad;
    auto outputs = lossGrad.evaluate({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    // Compare shape first, then values.
    auto *gradW = outputs->at(1);
    ASSERT_TRUE(expectedGradW.isSameShape(gradW));
    ASSERT_TRUE(expectedGradW.equalsTo(gradW));

    delete outputs;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test12) {
    // Verifies all three outputs of mean_sqerr_loss_grad for 2x3x4 inputs
    // with full-shape weights whose first four elements are zeroed, using
    // integer argument 3.
    // NOTE(review): the integer argument presumably selects the reduction
    // mode - confirm against the op declaration.

    // Inputs.
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    labels.linspace(1);

    NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE);
    weights.assign(0.5);
    // Zero out elements 0..3 of the weights.
    for (int idx = 0; idx < 4; ++idx)
        weights.t<double>(idx) = 0.;

    // Reference gradients.
    NDArray expectedGradP('c', {2,3,4}, {0., 0., 0., 0., -0.24, -0.288, -0.336, -0.384, -0.432, -0.48, -0.528, -0.576,
                                         -0.624, -0.672, -0.72, -0.768, -0.816, -0.864, -0.912, -0.96, -1.008, -1.056, -1.104, -1.152});
    NDArray expectedGradW('c', {2,3,4}, {0.04608, 0.18432, 0.41472, 0.73728, 1.152, 1.65888, 2.25792, 2.94912, 3.73248, 4.608, 5.57568, 6.63552,
                                         7.78752, 9.03168, 10.368, 11.79648, 13.31712, 14.92992, 16.63488, 18.432, 20.32128, 22.30272, 24.37632, 26.54208});

    // Execute the op; inputs are passed as predictions, weights, labels.
    sd::ops::mean_sqerr_loss_grad lossGrad;
    auto outputs = lossGrad.evaluate({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(expectedGradP.isSameShape(gradP));
    ASSERT_TRUE(expectedGradP.equalsTo(gradP));

    ASSERT_TRUE(expectedGradW.isSameShape(gradW));
    ASSERT_TRUE(expectedGradW.equalsTo(gradW));

    // The labels gradient is checked as the negation of the predictions reference.
    const auto negatedGradL = -(*gradL);
    ASSERT_TRUE(expectedGradP.isSameShape(negatedGradL));
    ASSERT_TRUE(expectedGradP.equalsTo(negatedGradL));

    delete outputs;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test13) {
    // Verifies all three outputs of mean_sqerr_loss_grad for 2x3x4 inputs
    // with {2,3,1} weights whose first three elements are zeroed, using
    // integer argument 3.
    // NOTE(review): the integer argument presumably selects the reduction
    // mode - confirm against the op declaration.

    // Inputs.
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    labels.linspace(1);

    NDArray weights('c', {2,3,1}, sd::DataType::DOUBLE);
    weights.assign(0.5);
    // Zero out elements 0..2 of the weights.
    for (int idx = 0; idx < 3; ++idx)
        weights.t<double>(idx) = 0.;

    // Reference gradients.
    NDArray expectedGradP('c', {2,3,4}, {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                         -1.04, -1.12, -1.2, -1.28, -1.36, -1.44, -1.52, -1.6, -1.68, -1.76, -1.84, -1.92});
    NDArray expectedGradW('c', {2,3,1}, {2.304, 13.3632, 34.2528, 64.97279, 105.5232, 155.90401});

    // Execute the op; inputs are passed as predictions, weights, labels.
    sd::ops::mean_sqerr_loss_grad lossGrad;
    auto outputs = lossGrad.evaluate({&predictions, &weights, &labels}, {}, {3});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(expectedGradP.isSameShape(gradP));
    ASSERT_TRUE(expectedGradP.equalsTo(gradP));

    ASSERT_TRUE(expectedGradW.isSameShape(gradW));
    ASSERT_TRUE(expectedGradW.equalsTo(gradW));

    // The labels gradient is checked as the negation of the predictions reference.
    const auto negatedGradL = -(*gradL);
    ASSERT_TRUE(expectedGradP.isSameShape(negatedGradL));
    ASSERT_TRUE(expectedGradP.equalsTo(negatedGradL));

    delete outputs;
}
|
|
|
|
|
|
|
|
TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test1) {
    // squaredsubtract on two float vectors of length 4; the reference holds
    // the element-wise squared differences of the two inputs.
    auto lhs = NDArrayFactory::create<float>('c', {4}, {0, 1, 2, 3});
    auto rhs = NDArrayFactory::create<float>('c', {4}, {3, 2, 1, 0});
    auto expected = NDArrayFactory::create<float>('c', {4}, {9, 1, 1, 9});

    sd::ops::squaredsubtract sqSub;
    auto outcome = sqSub.evaluate({&lhs, &rhs}, {}, {});

    ASSERT_EQ(Status::OK(), outcome->status());
    ASSERT_TRUE(expected.equalsTo(outcome->at(0)));

    delete outcome;
}
|
|
|
|
|
|
|
|
TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test2) {
    // squaredsubtract with a 2x4 left operand and a length-4 right operand;
    // the reference repeats the same squared differences in both rows.
    auto lhs = NDArrayFactory::create<float>('c', {2, 4}, {0, 1, 2, 3, 0, 1, 2, 3});
    auto rhs = NDArrayFactory::create<float>('c', {4}, {3, 2, 1, 0});
    auto expected = NDArrayFactory::create<float>('c', {2, 4}, {9, 1, 1, 9, 9, 1, 1, 9});

    sd::ops::squaredsubtract sqSub;
    auto outcome = sqSub.evaluate({&lhs, &rhs}, {}, {});

    ASSERT_EQ(Status::OK(), outcome->status());
    ASSERT_TRUE(expected.equalsTo(outcome->at(0)));

    delete outcome;
}
|
|
|
|
|
|
|
|
TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test3) {
    // squaredsubtract_bp with a 2x4 left operand, a length-4 right operand
    // and a 2x4 third input (named eps in the original). Only the first
    // output is compared against the reference.
    auto lhs = NDArrayFactory::create<float>('c', {2, 4}, {0, 1, 2, 3, 0, 1, 2, 3});
    auto rhs = NDArrayFactory::create<float>('c', {4}, {3, 2, 1, 0});
    auto epsilon = NDArrayFactory::create<float>('c', {2, 4}, {1, 2, 3, 4, 5, 6, 7, 8});
    auto expected = NDArrayFactory::create<float>('c', {2, 4}, {-6, -4, 6, 24, -30, -12, 14, 48});

    sd::ops::squaredsubtract_bp sqSubBp;
    auto outcome = sqSubBp.evaluate({&lhs, &rhs, &epsilon}, {}, {});

    ASSERT_EQ(Status::OK(), outcome->status());
    ASSERT_TRUE(expected.equalsTo(outcome->at(0)));

    delete outcome;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test1) {
    // Verifies all three outputs of absolute_difference_loss_grad (gradients
    // named dLdp, dLdw, dLdl in output order) for 2x3x4 inputs with
    // full-shape weights and integer argument 0.
    // NOTE(review): the integer argument presumably selects the reduction
    // mode - confirm against the op declaration.

    // Inputs.
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    labels.linspace(1);

    NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE);
    weights.assign(0.5);

    // Reference gradients.
    NDArray expectedGradP('c', {2,3,4}, {-0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5,
                                         -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5});
    NDArray expectedGradW('c', {2,3,4}, {0.96, 1.92, 2.88, 3.84, 4.8, 5.76, 6.72, 7.68, 8.64, 9.6, 10.56, 11.52,
                                         12.48, 13.44, 14.4, 15.36, 16.32, 17.28, 18.24, 19.2, 20.16, 21.12, 22.08, 23.04});

    // Execute the op; inputs are passed as predictions, weights, labels.
    sd::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.evaluate({&predictions, &weights, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *gradP = outputs->at(0);
    auto *gradW = outputs->at(1);
    auto *gradL = outputs->at(2);

    ASSERT_TRUE(expectedGradP.isSameShape(gradP));
    ASSERT_TRUE(expectedGradP.equalsTo(gradP));

    ASSERT_TRUE(expectedGradW.isSameShape(gradW));
    ASSERT_TRUE(expectedGradW.equalsTo(gradW));

    // The labels gradient is checked as the negation of the predictions reference.
    const auto negatedGradL = -(*gradL);
    ASSERT_TRUE(expectedGradP.isSameShape(negatedGradL));
    ASSERT_TRUE(expectedGradP.equalsTo(negatedGradL));

    delete outputs;
}
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test2) {
    // Verifies only the weights gradient (output #1) of
    // absolute_difference_loss_grad for 2x3x4 predictions/labels, {2,1,4}
    // weights and integer argument 0.
    // NOTE(review): the integer argument presumably selects the reduction
    // mode - confirm against the op declaration.

    // Inputs.
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    predictions.linspace(0.04, 0.04);

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    labels.linspace(1);

    NDArray weights('c', {2,1,4}, sd::DataType::DOUBLE);
    weights.assign(0.5);

    // Reference gradient with respect to the weights.
    NDArray expectedGradW('c', {2,1,4}, {14.4, 17.28, 20.16, 23.04, 48.96, 51.84, 54.72, 57.6});

    // Execute the op; inputs are passed as predictions, weights, labels.
    sd::ops::absolute_difference_loss_grad lossGrad;
    auto outputs = lossGrad.evaluate({&predictions, &weights, &labels}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    // Compare shape first, then values.
    auto *gradW = outputs->at(1);
    ASSERT_TRUE(expectedGradW.isSameShape(gradW));
    ASSERT_TRUE(expectedGradW.equalsTo(gradW));

    delete outputs;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) {
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray weights(sd::DataType::DOUBLE);
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,
|
|
|
|
-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5});
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray dLdwExp('c', {}, std::vector<double>{288.});
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
predictions.linspace(0.04, 0.04);
|
|
|
|
labels.linspace(1);
|
2019-09-11 19:12:09 +02:00
|
|
|
weights.assign(0.5);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::absolute_difference_loss_grad op;
|
2020-01-30 08:07:24 +01:00
|
|
|
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
|
2019-09-11 19:12:09 +02:00
|
|
|
auto *dLdp = results->at(0);
|
2019-06-06 14:21:15 +02:00
|
|
|
auto *dLdw = results->at(1);
|
|
|
|
auto *dLdl = results->at(2);
|
|
|
|
|
|
|
|
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
|
|
|
|
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
|
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
|
|
ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
|
|
|
|
ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
|
|
|
|
|
|
|
|
delete results;
|
|
|
|
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
// Runs absolute_difference_loss_grad with integer argument 1 on {2,3,4} predictions/labels
// and {1,3,1} weights, and checks only output 1 (the gradient named dLdw here).
// NOTE(review): the integer argument presumably selects the loss reduction mode - confirm
// against the op declaration.
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test4) {

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE);

    NDArray dLdwExp('c', {1,3,1}, {65.28, 96., 126.72001});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    sd::ops::absolute_difference_loss_grad op;
    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});

    // A failing ASSERT_* returns from the test body immediately, so the result set is
    // released by this scope guard instead of a trailing `delete` that would be skipped.
    struct ResultsGuard {
        decltype(results) ptr;
        ~ResultsGuard() { delete ptr; }
    } resultsGuard{results};

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdw = results->at(1);

    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
// Runs absolute_difference_loss_grad with integer argument 2 on {2,3,4} predictions, labels
// and weights, then checks all three outputs. Output 2 (dLdl) is expected to be the
// negation of the expected prediction gradient.
// NOTE(review): the integer argument presumably selects the loss reduction mode - confirm
// against the op declaration.
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test5) {

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,3,4}, {-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,
                                   -0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167});
    NDArray dLdwExp('c', {2,3,4}, {-0.92,-0.84,-0.76,-0.68,-0.6 ,-0.52,-0.44,-0.36,-0.28,-0.2 ,-0.12,-0.04,
                                   0.04, 0.12, 0.2 , 0.28, 0.36, 0.44, 0.52, 0.6 , 0.68, 0.76, 0.84, 0.92});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    sd::ops::absolute_difference_loss_grad op;
    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});

    // A failing ASSERT_* returns from the test body immediately, so the result set is
    // released by this scope guard instead of a trailing `delete` that would be skipped.
    struct ResultsGuard {
        decltype(results) ptr;
        ~ResultsGuard() { delete ptr; }
    } resultsGuard{results};

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
    ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
// Runs absolute_difference_loss_grad with integer argument 2 on {2,3,4} predictions/labels
// and {1,3,1} weights, and checks only output 1 (the gradient named dLdw here).
// NOTE(review): the integer argument presumably selects the loss reduction mode - confirm
// against the op declaration.
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test6) {

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE);

    NDArray dLdwExp('c', {1,3,1}, {-2.56, 0., 2.56});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    sd::ops::absolute_difference_loss_grad op;
    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});

    // A failing ASSERT_* returns from the test body immediately, so the result set is
    // released by this scope guard instead of a trailing `delete` that would be skipped.
    struct ResultsGuard {
        decltype(results) ptr;
        ~ResultsGuard() { delete ptr; }
    } resultsGuard{results};

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdw = results->at(1);

    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) {
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray weights(sd::DataType::DOUBLE);
|
2019-09-11 19:12:09 +02:00
|
|
|
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray dLdwExp('c', {}, std::vector<double>{0.});
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
predictions.linspace(0.04, 0.04);
|
|
|
|
labels.linspace(1);
|
2019-09-11 19:12:09 +02:00
|
|
|
weights.assign(0.5);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::absolute_difference_loss_grad op;
|
2020-01-30 08:07:24 +01:00
|
|
|
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
|
|
|
|
auto *dLdw = results->at(1);
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
|
|
|
|
|
|
delete results;
|
|
|
|
}
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
// Runs absolute_difference_loss_grad with integer argument 2 on {2,3,4} predictions, labels
// and weights, where the first four weights are set to zero, then checks all three outputs.
// Output 2 (dLdl) is expected to be the negation of the expected prediction gradient.
// NOTE(review): the integer argument presumably selects the loss reduction mode - confirm
// against the op declaration.
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test8) {

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,3,4}, {-0. ,-0. ,-0. ,-0. ,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,
                                   -0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05});
    NDArray dLdwExp('c', {2,3,4}, {-1.296,-1.2 ,-1.104,-1.008,-0.912,-0.816,-0.72 ,-0.624,-0.528,-0.432,-0.336,-0.24 ,
                                   -0.144,-0.048, 0.048, 0.144, 0.24 , 0.336, 0.432, 0.528, 0.624, 0.72 , 0.816, 0.912});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // Zero out the first four weights; the expected gradients above are zero at the same positions.
    weights.p(0, 0.);
    weights.p(1, 0.);
    weights.p(2, 0.);
    weights.p(3, 0.);

    sd::ops::absolute_difference_loss_grad op;
    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});

    // A failing ASSERT_* returns from the test body immediately, so the result set is
    // released by this scope guard instead of a trailing `delete` that would be skipped.
    struct ResultsGuard {
        decltype(results) ptr;
        ~ResultsGuard() { delete ptr; }
    } resultsGuard{results};

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
    ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
// Runs absolute_difference_loss_grad with integer argument 3 on {2,3,4} predictions, labels
// and weights, then checks all three outputs. Output 2 (dLdl) is expected to be the
// negation of the expected prediction gradient.
// NOTE(review): the integer argument presumably selects the loss reduction mode - confirm
// against the op declaration.
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test9) {

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,3,4}, {-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,
                                   -0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083});
    NDArray dLdwExp('c', {2,3,4}, {0.04, 0.08, 0.12, 0.16, 0.2 , 0.24, 0.28, 0.32,0.36, 0.4 , 0.44, 0.48,
                                   0.52, 0.56, 0.6 , 0.64,0.68, 0.72, 0.76, 0.8 ,0.84, 0.88, 0.92, 0.96});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    sd::ops::absolute_difference_loss_grad op;
    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});

    // A failing ASSERT_* returns from the test body immediately, so the result set is
    // released by this scope guard instead of a trailing `delete` that would be skipped.
    struct ResultsGuard {
        decltype(results) ptr;
        ~ResultsGuard() { delete ptr; }
    } resultsGuard{results};

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
    ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test10) {
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray weights('c', {1,1}, sd::DataType::DOUBLE);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray dLdwExp('c', {1,1}, std::vector<double>{12.});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
predictions.linspace(0.04, 0.04);
|
|
|
|
labels.linspace(1);
|
|
|
|
weights.assign(0.5);
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::absolute_difference_loss_grad op;
|
2020-01-30 08:07:24 +01:00
|
|
|
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
2019-09-11 19:12:09 +02:00
|
|
|
|
|
|
|
auto *dLdw = results->at(1);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
|
|
|
|
|
|
delete results;
|
|
|
|
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
// Runs absolute_difference_loss_grad with integer argument 3 on {2,3,4} predictions/labels
// and {1,3,1} weights, and checks only output 1 (the gradient named dLdw here).
// NOTE(review): the integer argument presumably selects the loss reduction mode - confirm
// against the op declaration.
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test11) {

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE);

    NDArray dLdwExp('c', {1,3,1}, {2.72, 4., 5.28});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    sd::ops::absolute_difference_loss_grad op;
    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});

    // A failing ASSERT_* returns from the test body immediately, so the result set is
    // released by this scope guard instead of a trailing `delete` that would be skipped.
    struct ResultsGuard {
        decltype(results) ptr;
        ~ResultsGuard() { delete ptr; }
    } resultsGuard{results};

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdw = results->at(1);

    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
// Runs absolute_difference_loss_grad with integer argument 3 on {2,3,4} predictions, labels
// and weights, where the first four weights are set to zero, then checks all three outputs.
// Output 2 (dLdl) is expected to be the negation of the expected prediction gradient.
// NOTE(review): the integer argument presumably selects the loss reduction mode - confirm
// against the op declaration.
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test12) {

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,3,4}, {0., 0., 0., 0., -0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025,
                                   -0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025});
    NDArray dLdwExp('c', {2,3,4}, {0.048, 0.096, 0.144, 0.192,0.24 , 0.288, 0.336, 0.384,0.432, 0.48 , 0.528, 0.576,
                                   0.624, 0.672, 0.72 , 0.768,0.816, 0.864, 0.912, 0.96 ,1.008, 1.056, 1.104, 1.152});

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // Zero out the first four weights; the expected prediction gradient is zero at the same positions.
    weights.t<double>(0) = 0.;
    weights.t<double>(1) = 0.;
    weights.t<double>(2) = 0.;
    weights.t<double>(3) = 0.;

    sd::ops::absolute_difference_loss_grad op;
    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});

    // A failing ASSERT_* returns from the test body immediately, so the result set is
    // released by this scope guard instead of a trailing `delete` that would be skipped.
    struct ResultsGuard {
        decltype(results) ptr;
        ~ResultsGuard() { delete ptr; }
    } resultsGuard{results};

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
    ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
// Runs absolute_difference_loss_grad with integer argument 3 on {2,3,4} predictions/labels
// and {2,3,1} weights, where the first three weights are set to zero, then checks all three
// outputs. Output 2 (dLdl) is expected to be the negation of the expected prediction gradient.
// NOTE(review): the integer argument presumably selects the loss reduction mode - confirm
// against the op declaration.
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test13) {

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {2,3,1}, sd::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,3,4}, {0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0.,
                                   -0.04167, -0.04167, -0.04167, -0.04167,-0.04167, -0.04167, -0.04167, -0.04167,-0.04167, -0.04167, -0.04167, -0.04167});
    NDArray dLdwExp('c', {2,3,1}, {0.8 ,2.08,3.36,4.64,5.92,7.2 });

    predictions.linspace(0.04, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // Zero out the first three weights; the first twelve expected prediction gradients are zero.
    weights.t<double>(0) = 0.;
    weights.t<double>(1) = 0.;
    weights.t<double>(2) = 0.;

    sd::ops::absolute_difference_loss_grad op;
    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});

    // A failing ASSERT_* returns from the test body immediately, so the result set is
    // released by this scope guard instead of a trailing `delete` that would be skipped.
    struct ResultsGuard {
        decltype(results) ptr;
        ~ResultsGuard() { delete ptr; }
    } resultsGuard{results};

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
    ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
    ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
// Adds two {2,3,4} bfloat16 arrays, each filled by linspace(1), with the `add` op and
// checks output 0 against a bfloat16 array filled by linspace(2,2).
TEST_F(DeclarableOpsTests11, BFloat16_Test_1) {

    NDArray x = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    NDArray y = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    NDArray exp = NDArrayFactory::create<bfloat16>('c', {2,3,4});

    x.linspace(1);
    y.linspace(1);
    exp.linspace(2,2);

    sd::ops::add op;
    auto results = op.evaluate({&x, &y}, {}, {});

    // A failing ASSERT_* returns from the test body immediately, so the result set is
    // released by this scope guard instead of a trailing `delete` that would be skipped.
    struct ResultsGuard {
        decltype(results) ptr;
        ~ResultsGuard() { delete ptr; }
    } resultsGuard{results};

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto res = results->at(0);
    ASSERT_TRUE(res->equalsTo(exp));
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, BFloat16_Test_2) {
    // Mixed-type case: the first operand and the expectation are float16,
    // the second operand is bfloat16.
    auto lhs      = NDArrayFactory::create<float16>('c', {2,3,4});
    auto rhs      = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    auto expected = NDArrayFactory::create<float16>('c', {2,3,4});

    lhs.linspace(1);
    rhs.linspace(1);
    expected.linspace(2,2);

    sd::ops::add addOp;
    auto outcome = addOp.evaluate({&lhs, &rhs}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto sum = outcome->at(0);
    ASSERT_TRUE(sum->equalsTo(expected));

    delete outcome;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, BFloat16_Test_3) {
    // Same scenario as the factory-based bfloat16 test, but the arrays are built
    // directly with the NDArray constructor that takes a DataType.
    NDArray lhs('c', {2,3,4}, sd::DataType::BFLOAT16);
    lhs.linspace(1);

    NDArray rhs('c', {2,3,4}, sd::DataType::BFLOAT16);
    rhs.linspace(1);

    NDArray expected('c', {2,3,4}, sd::DataType::BFLOAT16);
    expected.linspace(2,2);

    sd::ops::add addOp;
    auto outcome = addOp.evaluate({&lhs, &rhs}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto sum = outcome->at(0);
    ASSERT_TRUE(sum->equalsTo(expected));

    delete outcome;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test1) {
    // Inputs: full-shape {2,3,4} weights, all set to 0.5.
    NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    labels.linspace(1);

    NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE);
    weights.assign(0.5);

    // Reference values for op outputs 0, 1 and 2 respectively.
    NDArray dLdpExp('c', {2,3,4}, {
        -0.25999, -0.755 , -1.25 , -1.745 , -2.24001, -2.73502, -3.23004, -3.72508, -4.22014, -4.71523, -5.21034, -5.70548,
        -6.20066, -6.69587, -7.19113, -7.68643, -8.18177, -8.67717, -9.17262, -9.66813,-10.1637 ,-10.65932,-11.15501,-11.65077});
    NDArray dLdwExp('c', {2,3,4}, {
        0.73395, 0.75335, 0.69315, 0.55335, 0.33395, 0.03495, -0.34366, -0.80186, -1.33967, -1.95708, -2.65411, -3.43074,
        -4.28698, -5.22285, -6.23833, -7.33343, -8.50815, -9.76251,-11.0965 ,-12.51013,-14.00341,-15.57633,-17.2289 ,-18.96113});
    NDArray dLdlExp('c', {2,3,4}, {
        0.04, 0.02,-0. ,-0.02,-0.04,-0.06,-0.08,-0.1 ,-0.12,-0.14,-0.16,-0.18,
        -0.2 ,-0.22,-0.24,-0.26,-0.28,-0.3 ,-0.32,-0.34,-0.36,-0.38,-0.4 ,-0.42});

    // T-arg 0., I-arg 0 (presumably label smoothing and reduction mode -- confirm against the op declaration).
    sd::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outcome = lossGrad.evaluate({&logits, &weights, &labels}, {0.}, {0});

    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *gradP = outcome->at(0);
    auto *gradW = outcome->at(1);
    auto *gradL = outcome->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(gradP));
    ASSERT_TRUE(dLdpExp.equalsTo(gradP));
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));
    ASSERT_TRUE(dLdlExp.isSameShape(gradL));
    ASSERT_TRUE(dLdlExp.equalsTo(gradL));

    delete outcome;
}
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test2) {
    // Inputs: weights have shape {2,1,4} while logits/labels are {2,3,4}.
    NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    labels.linspace(1);

    NDArray weights('c', {2,1,4}, sd::DataType::DOUBLE);
    weights.assign(0.5);

    // Reference values for op outputs 0, 1 and 2 respectively.
    NDArray dLdpExp('c', {2,3,4}, {
        -0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048,
        -4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577});
    NDArray dLdwExp('c', {2,1,4}, {0.43622, -0.19079, -0.98462, -1.94525,-18.09855,-20.72768,-23.52373,-26.48669});
    NDArray dLdlExp('c', {2,3,4}, {
        0.028, 0.014, -0. , -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126,
        -0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294});

    // T-arg 0.3, I-arg 0 (presumably label smoothing and reduction mode -- confirm against the op declaration).
    sd::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outcome = lossGrad.evaluate({&logits, &weights, &labels}, {0.3}, {0});

    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *gradP = outcome->at(0);
    auto *gradW = outcome->at(1);
    auto *gradL = outcome->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(gradP));
    ASSERT_TRUE(dLdpExp.equalsTo(gradP));
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));
    ASSERT_TRUE(dLdlExp.isSameShape(gradL));
    ASSERT_TRUE(dLdlExp.equalsTo(gradL));

    delete outcome;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) {
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray weights(sd::DataType::DOUBLE);
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048,
|
|
|
|
-4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577});
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray dLdwExp('c', {}, std::vector<double>{-91.52109});
|
2019-06-06 14:21:15 +02:00
|
|
|
NDArray dLdlExp('c', {2,3,4}, {0.028, 0.014, -0., -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126,
|
|
|
|
-0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294});
|
|
|
|
|
|
|
|
logits.linspace(-0.08, 0.04);
|
|
|
|
labels.linspace(1);
|
2019-09-11 19:12:09 +02:00
|
|
|
weights.assign(0.5);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::sigm_cross_entropy_loss_grad op;
|
2020-01-30 08:07:24 +01:00
|
|
|
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {1});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
|
2019-09-11 19:12:09 +02:00
|
|
|
auto *dLdp = results->at(0);
|
2019-06-06 14:21:15 +02:00
|
|
|
auto *dLdw = results->at(1);
|
|
|
|
auto *dLdl = results->at(2);
|
|
|
|
|
|
|
|
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
|
|
|
|
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
|
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
|
|
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
|
|
|
|
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
|
|
|
|
|
|
|
|
delete results;
|
|
|
|
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test4) {
    // Inputs: weights have shape {1,3,1}; only the weights gradient (output 1) is checked.
    NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    labels.linspace(1);

    NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE);
    weights.assign(0.5);

    NDArray dLdwExp('c', {1,3,1}, {-12.54779,-28.13393,-50.83936});

    // T-arg 0.3, I-arg 1 (presumably label smoothing and reduction mode -- confirm against the op declaration).
    sd::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outcome = lossGrad.evaluate({&logits, &weights, &labels}, {0.3}, {1});

    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *gradW = outcome->at(1);

    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));

    delete outcome;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test5) {
    // Inputs: full-shape {2,3,4} weights, all set to 0.5.
    NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    labels.linspace(1);

    NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE);
    weights.assign(0.5);

    // Reference values for op outputs 0, 1 and 2 respectively.
    NDArray dLdpExp('c', {2,3,4}, {
        -0.01542,-0.04417,-0.07292,-0.10167,-0.13042,-0.15917,-0.18792,-0.21667,-0.24543,-0.27419,-0.30294,-0.33171,
        -0.36047,-0.38924,-0.41801,-0.44679,-0.47556,-0.50435,-0.53314,-0.56193,-0.59072,-0.61953,-0.64833,-0.67715});
    NDArray dLdwExp('c', {2,3,4}, {
        0.37794, 0.37906, 0.37554, 0.36739, 0.35461, 0.33719, 0.31514, 0.28846, 0.25714, 0.22119, 0.18061, 0.13539,
        0.08553, 0.03104,-0.02808,-0.09184,-0.16023,-0.23326,-0.31093,-0.39323,-0.48017,-0.57175,-0.66796,-0.76881});
    NDArray dLdlExp('c', {2,3,4}, {
        0.00233, 0.00117,-0.,-0.00117,-0.00233,-0.0035 ,-0.00467,-0.00583,-0.007 ,-0.00817,-0.00933,-0.0105,
        -0.01167,-0.01283,-0.014 ,-0.01517,-0.01633,-0.0175 ,-0.01867,-0.01983,-0.021 ,-0.02217,-0.02333,-0.0245});

    // T-arg 0.3, I-arg 2 (presumably label smoothing and reduction mode -- confirm against the op declaration).
    sd::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outcome = lossGrad.evaluate({&logits, &weights, &labels}, {0.3}, {2});

    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *gradP = outcome->at(0);
    auto *gradW = outcome->at(1);
    auto *gradL = outcome->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(gradP));
    ASSERT_TRUE(dLdpExp.equalsTo(gradP));
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));
    ASSERT_TRUE(dLdlExp.isSameShape(gradL));
    ASSERT_TRUE(dLdlExp.equalsTo(gradL));

    delete outcome;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test6) {
    // Inputs: weights have shape {1,3,1}; only the weights gradient (output 1) is checked.
    NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    labels.linspace(1);

    NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE);
    weights.assign(0.5);

    NDArray dLdwExp('c', {1,3,1}, {1.4966 , 0.19776,-1.69436});

    // T-arg 0.3, I-arg 2 (presumably label smoothing and reduction mode -- confirm against the op declaration).
    sd::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outcome = lossGrad.evaluate({&logits, &weights, &labels}, {0.3}, {2});

    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *gradW = outcome->at(1);

    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));

    delete outcome;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) {
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray weights(sd::DataType::DOUBLE);
|
2019-09-11 19:12:09 +02:00
|
|
|
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray dLdwExp('c', {}, std::vector<double>{0.});
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
logits.linspace(-0.08, 0.04);
|
|
|
|
labels.linspace(1);
|
2019-09-11 19:12:09 +02:00
|
|
|
weights.assign(0.5);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::sigm_cross_entropy_loss_grad op;
|
2020-01-30 08:07:24 +01:00
|
|
|
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
|
|
|
|
auto *dLdw = results->at(1);
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
|
|
|
|
|
|
delete results;
|
|
|
|
}
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test8) {
    NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    labels.linspace(1);

    // Every weight is 0.5 except the first four elements, which are zeroed.
    NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE);
    weights.assign(0.5);
    for (int idx = 0; idx < 4; ++idx)
        weights.p(idx, 0.);

    // Reference values for op outputs 0, 1 and 2 respectively.
    NDArray dLdpExp('c', {2,3,4}, {
        0. , 0. , 0. , 0. ,-0.1565 ,-0.191 ,-0.2255 ,-0.26001,-0.29451,-0.32902,-0.36353,-0.39805,
        -0.43257,-0.46709,-0.50161,-0.53614,-0.57068,-0.60522,-0.63976,-0.67431,-0.70887,-0.74343,-0.778 ,-0.81258});
    NDArray dLdwExp('c', {2,3,4}, {
        0.54353, 0.54487, 0.54065, 0.53087, 0.51553, 0.49463, 0.46817, 0.43615, 0.39857, 0.35543, 0.30672, 0.25246,
        0.19264, 0.12725, 0.0563 ,-0.02021,-0.10228,-0.18992,-0.28312,-0.38188,-0.48621,-0.5961 ,-0.71156,-0.83258});
    NDArray dLdlExp('c', {2,3,4}, {
        -0. ,-0. , 0. , 0. ,-0.0028,-0.0042,-0.0056,-0.007 ,-0.0084,-0.0098,-0.0112,-0.0126,
        -0.014 ,-0.0154,-0.0168,-0.0182,-0.0196,-0.021 ,-0.0224,-0.0238,-0.0252,-0.0266,-0.028 ,-0.0294});

    // T-arg 0.3, I-arg 2 (presumably label smoothing and reduction mode -- confirm against the op declaration).
    sd::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outcome = lossGrad.evaluate({&logits, &weights, &labels}, {0.3}, {2});

    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *gradP = outcome->at(0);
    auto *gradW = outcome->at(1);
    auto *gradL = outcome->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(gradP));
    ASSERT_TRUE(dLdpExp.equalsTo(gradP));
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));
    ASSERT_TRUE(dLdlExp.isSameShape(gradL));
    ASSERT_TRUE(dLdlExp.equalsTo(gradL));

    delete outcome;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test9) {
    // Inputs: full-shape {2,3,4} weights, all set to 0.5.
    NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE);
    logits.linspace(-0.08, 0.04);

    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    labels.linspace(1);

    NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE);
    weights.assign(0.5);

    // Reference values for op outputs 0, 1 and 2 respectively.
    NDArray dLdpExp('c', {2,3,4}, {
        -0.00771, -0.02208, -0.03646, -0.05083,-0.06521, -0.07958, -0.09396, -0.10834,-0.12271, -0.13709, -0.15147, -0.16585,
        -0.18024, -0.19462, -0.20901, -0.22339,-0.23778, -0.25217, -0.26657, -0.28096,-0.29536, -0.30976, -0.32417, -0.33857});
    NDArray dLdwExp('c', {2,3,4}, {
        0.03008, 0.03064, 0.02888, 0.02481, 0.01841, 0.00971, -0.00132, -0.01466,-0.03032, -0.0483 , -0.06859, -0.0912 ,
        -0.11612, -0.14337, -0.17293, -0.20481,-0.23901, -0.27552, -0.31435, -0.35551,-0.39898, -0.44476, -0.49287, -0.5433 });
    NDArray dLdlExp('c', {2,3,4}, {
        0.00117, 0.00058, -0. , -0.00058,-0.00117, -0.00175, -0.00233, -0.00292,-0.0035 , -0.00408, -0.00467, -0.00525,
        -0.00583, -0.00642, -0.007 , -0.00758,-0.00817, -0.00875, -0.00933, -0.00992,-0.0105 , -0.01108, -0.01167, -0.01225});

    // T-arg 0.3, I-arg 3 (presumably label smoothing and reduction mode -- confirm against the op declaration).
    sd::ops::sigm_cross_entropy_loss_grad lossGrad;
    auto outcome = lossGrad.evaluate({&logits, &weights, &labels}, {0.3}, {3});

    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *gradP = outcome->at(0);
    auto *gradW = outcome->at(1);
    auto *gradL = outcome->at(2);

    ASSERT_TRUE(dLdpExp.isSameShape(gradP));
    ASSERT_TRUE(dLdpExp.equalsTo(gradP));
    ASSERT_TRUE(dLdwExp.isSameShape(gradW));
    ASSERT_TRUE(dLdwExp.equalsTo(gradW));
    ASSERT_TRUE(dLdlExp.isSameShape(gradL));
    ASSERT_TRUE(dLdlExp.equalsTo(gradL));

    delete outcome;
}
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test10) {
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray weights('c', {1,1}, sd::DataType::DOUBLE);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray dLdwExp('c', {1,1}, std::vector<double>{-3.81338});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
logits.linspace(-0.08, 0.04);
|
|
|
|
labels.linspace(1);
|
|
|
|
weights.assign(0.5);
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::sigm_cross_entropy_loss_grad op;
|
2020-01-30 08:07:24 +01:00
|
|
|
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
2019-09-11 19:12:09 +02:00
|
|
|
|
|
|
|
auto *dLdw = results->at(1);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
|
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
|
|
|
|
|
|
|
delete results;
|
|
|
|
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test11) {
    // Runs sigm_cross_entropy_loss_grad with {2,3,4} logits/labels and {1,3,1}
    // weights, then verifies only the gradient w.r.t. weights (op output 1).

    // --- inputs ---
    NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE);

    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);

    // --- reference values for the weights gradient ---
    NDArray expectedGradW('c', {1,3,1}, {-0.52282,-1.17225,-2.11831});

    // --- execute ---
    // NOTE(review): {0.3} / {3} are presumably label smoothing and reduction mode - confirm against the op.
    sd::ops::sigm_cross_entropy_loss_grad lossGradOp;
    auto opOutputs = lossGradOp.evaluate({&logits, &weights, &labels}, {0.3}, {3});

    ASSERT_EQ(ND4J_STATUS_OK, opOutputs->status());

    // --- verify ---
    auto* gradW = opOutputs->at(1);

    ASSERT_TRUE(expectedGradW.isSameShape(gradW));
    ASSERT_TRUE(expectedGradW.equalsTo(gradW));

    delete opOutputs;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test12) {
    // Runs sigm_cross_entropy_loss_grad with full-shape {2,3,4} weights whose
    // first four elements are zero, and verifies all three outputs
    // (dLdp, dLdw, dLdl).

    // --- inputs ---
    NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE);

    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // zero out the leading four weights
    for (int i = 0; i < 4; ++i)
        weights.t<double>(i) = 0.;

    // --- reference values ---
    NDArray expectedGradP('c', {2,3,4}, { 0.     ,  0.     ,  0.     ,  0.     ,
                                         -0.07825, -0.0955 , -0.11275, -0.13   ,
                                         -0.14726, -0.16451, -0.18177, -0.19902,
                                         -0.21628, -0.23354, -0.25081, -0.26807,
                                         -0.28534, -0.30261, -0.31988, -0.33716,
                                         -0.35443, -0.37172, -0.389  , -0.40629});
    NDArray expectedGradW('c', {2,3,4}, { 0.0361 ,  0.03677,  0.03466,  0.02977,
                                          0.0221 ,  0.01165, -0.00158, -0.01759,
                                         -0.03638, -0.05795, -0.08231, -0.10944,
                                         -0.13935, -0.17204, -0.20752, -0.24577,
                                         -0.28681, -0.33063, -0.37723, -0.42661,
                                         -0.47877, -0.53372, -0.59144, -0.65196});
    NDArray expectedGradL('c', {2,3,4}, {-0.    , -0.    ,  0.    ,  0.    ,
                                         -0.0014, -0.0021, -0.0028, -0.0035,
                                         -0.0042, -0.0049, -0.0056, -0.0063,
                                         -0.007 , -0.0077, -0.0084, -0.0091,
                                         -0.0098, -0.0105, -0.0112, -0.0119,
                                         -0.0126, -0.0133, -0.014 , -0.0147});

    // --- execute ---
    // NOTE(review): {0.3} / {3} are presumably label smoothing and reduction mode - confirm against the op.
    sd::ops::sigm_cross_entropy_loss_grad lossGradOp;
    auto opOutputs = lossGradOp.evaluate({&logits, &weights, &labels}, {0.3}, {3});

    ASSERT_EQ(ND4J_STATUS_OK, opOutputs->status());

    // --- verify ---
    auto* gradP = opOutputs->at(0);
    auto* gradW = opOutputs->at(1);
    auto* gradL = opOutputs->at(2);

    ASSERT_TRUE(expectedGradP.isSameShape(gradP));
    ASSERT_TRUE(expectedGradP.equalsTo(gradP));
    ASSERT_TRUE(expectedGradW.isSameShape(gradW));
    ASSERT_TRUE(expectedGradW.equalsTo(gradW));
    ASSERT_TRUE(expectedGradL.isSameShape(gradL));
    ASSERT_TRUE(expectedGradL.equalsTo(gradL));

    delete opOutputs;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test13) {
    // Runs sigm_cross_entropy_loss_grad with {2,3,1} weights whose first three
    // elements are zero, and verifies all three outputs (dLdp, dLdw, dLdl).

    // --- inputs ---
    NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {2,3,1}, sd::DataType::DOUBLE);

    logits.linspace(-0.08, 0.04);
    labels.linspace(1);
    weights.assign(0.5);
    // zero out the leading three weights
    for (int i = 0; i < 3; ++i)
        weights.t<double>(i) = 0.;

    // --- reference values ---
    NDArray expectedGradP('c', {2,3,4}, { 0.     ,  0.     ,  0.     ,  0.     ,
                                          0.     ,  0.     ,  0.     ,  0.     ,
                                          0.     ,  0.     ,  0.     ,  0.     ,
                                         -0.36047, -0.38924, -0.41801, -0.44679,
                                         -0.47556, -0.50435, -0.53314, -0.56193,
                                         -0.59072, -0.61953, -0.64833, -0.67715});
    NDArray expectedGradW('c', {2,3,1}, {0.22882, 0.02428, -0.4768, -1.27447, -2.36878, -3.75981});
    NDArray expectedGradL('c', {2,3,4}, {-0.     , -0.     ,  0.     ,  0.     ,
                                          0.     ,  0.     ,  0.     ,  0.     ,
                                          0.     ,  0.     ,  0.     ,  0.     ,
                                         -0.01167, -0.01283, -0.014  , -0.01517,
                                         -0.01633, -0.0175 , -0.01867, -0.01983,
                                         -0.021  , -0.02217, -0.02333, -0.0245 });

    // --- execute ---
    // NOTE(review): {0.3} / {3} are presumably label smoothing and reduction mode - confirm against the op.
    sd::ops::sigm_cross_entropy_loss_grad lossGradOp;
    auto opOutputs = lossGradOp.evaluate({&logits, &weights, &labels}, {0.3}, {3});

    ASSERT_EQ(ND4J_STATUS_OK, opOutputs->status());

    // --- verify ---
    auto* gradP = opOutputs->at(0);
    auto* gradW = opOutputs->at(1);
    auto* gradL = opOutputs->at(2);

    ASSERT_TRUE(expectedGradP.isSameShape(gradP));
    ASSERT_TRUE(expectedGradP.equalsTo(gradP));
    ASSERT_TRUE(expectedGradW.isSameShape(gradW));
    ASSERT_TRUE(expectedGradW.equalsTo(gradW));
    ASSERT_TRUE(expectedGradL.isSameShape(gradL));
    ASSERT_TRUE(expectedGradL.equalsTo(gradL));

    delete opOutputs;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, BFloat16_Test_4) {
    // Mixed-type add: a float array plus a bfloat16 array, compared against a
    // float reference array.

    // --- operands and reference, all of shape {2,3,4} ---
    auto lhs = NDArrayFactory::create<float>('c', {2,3,4});
    auto rhs = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    auto expected = NDArrayFactory::create<float>('c', {2,3,4});

    lhs.linspace(1);
    rhs.linspace(1);
    expected.linspace(2,2);

    // --- execute ---
    sd::ops::add addOp;
    auto opOutputs = addOp.evaluate({&lhs, &rhs}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, opOutputs->status());

    // --- verify ---
    auto actual = opOutputs->at(0);
    ASSERT_TRUE(actual->equalsTo(expected));

    delete opOutputs;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, BFloat16_Test_5) {
    // Mixed-type subtract: a float array minus a bfloat16 array, compared
    // against a float reference array.

    // --- operands and reference, all of shape {2,3,4} ---
    auto lhs = NDArrayFactory::create<float>('c', {2,3,4});
    auto rhs = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    auto expected = NDArrayFactory::create<float>('c', {2,3,4});

    lhs.linspace(2, 2);
    rhs.linspace(1);
    expected.linspace(1);

    // --- execute ---
    sd::ops::subtract subtractOp;
    auto opOutputs = subtractOp.evaluate({&lhs, &rhs}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, opOutputs->status());

    // --- verify ---
    auto actual = opOutputs->at(0);
    ASSERT_TRUE(actual->equalsTo(expected));

    delete opOutputs;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, BFloat16_Test_6) {
    // Mixed-type subtract: a bfloat16 array minus a double array, compared
    // against a bfloat16 reference array.

    // --- operands and reference, all of shape {2,3,4} ---
    auto lhs = NDArrayFactory::create<bfloat16>('c', {2,3,4});
    auto rhs = NDArrayFactory::create<double>('c', {2,3,4});
    auto expected = NDArrayFactory::create<bfloat16>('c', {2,3,4});

    lhs.linspace(2, 2);
    rhs.linspace(1);
    expected.linspace(1);

    // --- execute ---
    sd::ops::subtract subtractOp;
    auto opOutputs = subtractOp.evaluate({&lhs, &rhs}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, opOutputs->status());

    // --- verify ---
    auto actual = opOutputs->at(0);
    ASSERT_TRUE(actual->equalsTo(expected));

    delete opOutputs;
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test1) {
    // Runs softmax_cross_entropy_loss_grad on {2,4} logits with INT32 one-hot
    // labels and per-row weights of shape {2}; verifies the gradients w.r.t.
    // logits (output 0) and weights (output 1).

    NDArray labels('c', {2,4}, {0,0,1,0, 0,1,0,0}, sd::DataType::INT32);
    NDArray logits('c', {2,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {2}, sd::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,4}, {0.1176, 0.1224, -0.3726, 0.1326, 0.1176, -0.3776, 0.1274, 0.1326});
    NDArray dLdwExp('c', {2}, {1.36729, 1.40729});

    logits.linspace(-0.08, 0.04);
    weights.assign(0.5);

    sd::ops::softmax_cross_entropy_loss_grad op;

    // NOTE(review): {0.} / {0} are presumably label smoothing and reduction mode - confirm against the op.
    auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    // This test defines no reference values for output 2 (dLdl), so that output
    // is not fetched; the previously unused local was removed.
    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);

    ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
    ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
    ASSERT_TRUE(dLdwExp.equalsTo(dLdw));

    delete results;
}
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test2) {
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
NDArray labels('c', {4}, {0,0,1,0}, sd::DataType::INT32);
|
|
|
|
NDArray logits('c', {4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray weights('c', {1}, sd::DataType::DOUBLE);
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125});
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray dLdwExp('c', {1}, std::vector<double>{1.38629});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
logits = 2.;
|
2019-09-11 19:12:09 +02:00
|
|
|
weights.assign(0.5);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::softmax_cross_entropy_loss_grad op;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-01-30 08:07:24 +01:00
|
|
|
auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1});
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
|
2019-09-11 19:12:09 +02:00
|
|
|
auto *dLdp = results->at(0);
|
2019-06-06 14:21:15 +02:00
|
|
|
auto *dLdw = results->at(1);
|
2019-09-11 19:12:09 +02:00
|
|
|
auto *dLdl = results->at(2);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
|
|
|
|
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
|
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
2019-09-11 19:12:09 +02:00
|
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
delete results;
|
2019-09-11 19:12:09 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test3) {
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
NDArray labels('c', {4}, {0,0,1,0}, sd::DataType::INT32);
|
|
|
|
NDArray logits('c', {4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray weights('c', {}, std::vector<double>{0}, sd::DataType::DOUBLE);
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125});
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray dLdwExp('c', {}, std::vector<double>{1.38629});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
logits = 2.;
|
2019-09-11 19:12:09 +02:00
|
|
|
weights.assign(0.5);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::softmax_cross_entropy_loss_grad op;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-01-30 08:07:24 +01:00
|
|
|
auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1});
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
|
2019-09-11 19:12:09 +02:00
|
|
|
auto *dLdp = results->at(0);
|
2019-06-06 14:21:15 +02:00
|
|
|
auto *dLdw = results->at(1);
|
2019-09-11 19:12:09 +02:00
|
|
|
auto *dLdl = results->at(2);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
|
|
|
|
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
|
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
2019-09-11 19:12:09 +02:00
|
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
delete results;
|
2019-09-11 19:12:09 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test4) {
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
NDArray labels('c', {4}, {0,0,1,0}, sd::DataType::INT32);
|
|
|
|
NDArray logits('c', {4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray weights('c', {}, std::vector<double>{0}, sd::DataType::DOUBLE);
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
NDArray dLdpExp('c', {4}, {0.23521, 0.2448 , -0.7452 , 0.26519});
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray dLdwExp('c', {}, std::vector<double>{0.});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
logits.linspace(-0.08, 0.04);
|
2019-09-11 19:12:09 +02:00
|
|
|
weights = 0.5;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::softmax_cross_entropy_loss_grad op;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-01-30 08:07:24 +01:00
|
|
|
auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2});
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
|
2019-09-11 19:12:09 +02:00
|
|
|
auto *dLdp = results->at(0);
|
2019-06-06 14:21:15 +02:00
|
|
|
auto *dLdw = results->at(1);
|
2019-09-11 19:12:09 +02:00
|
|
|
auto *dLdl = results->at(2);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
|
|
|
|
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
|
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
2019-09-11 19:12:09 +02:00
|
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
delete results;
|
2019-09-11 19:12:09 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test5) {
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
NDArray labels('c', {4}, {0,0,1,0}, sd::DataType::INT32);
|
|
|
|
NDArray logits('c', {4}, sd::DataType::DOUBLE);
|
|
|
|
NDArray weights('c', {1}, sd::DataType::DOUBLE);
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
NDArray dLdpExp('c', {4}, {0.1176, 0.1224, -0.3726, 0.1326});
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray dLdwExp('c', {1}, std::vector<double>{1.36729});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
logits.linspace(-0.08, 0.04);
|
2019-09-11 19:12:09 +02:00
|
|
|
weights = 0.5;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::softmax_cross_entropy_loss_grad op;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-01-30 08:07:24 +01:00
|
|
|
auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3});
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
|
2019-09-11 19:12:09 +02:00
|
|
|
auto *dLdp = results->at(0);
|
2019-06-06 14:21:15 +02:00
|
|
|
auto *dLdw = results->at(1);
|
2019-09-11 19:12:09 +02:00
|
|
|
auto *dLdl = results->at(2);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
|
|
|
|
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
|
|
|
|
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
|
2019-09-11 19:12:09 +02:00
|
|
|
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
delete results;
|
2019-09-11 19:12:09 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test6) {
    // Evaluates softmax_cross_entropy_loss_grad on {2,4} logits/labels with a
    // {2} weights array (t-arg 0.3, i-arg 2) and compares outputs 0 and 1 with
    // stored reference values.

    NDArray labels('c', {2,4}, {0,0,1,0, 0,1,0,0}, sd::DataType::INT32);
    NDArray logits('c', {2,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {2}, sd::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,4}, {0.0801, 0.0849, -0.2601, 0.0951, 0.0801, -0.2651, 0.0899, 0.0951});
    NDArray dLdwExp('c', {2}, {-0.014000, 0.014000});

    logits.linspace(-0.08, 0.04);
    weights.assign(0.5);

    sd::ops::softmax_cross_entropy_loss_grad op;
    auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2});

    // A failed ASSERT_* returns from the test body immediately, so the result set
    // has to be released before a bad status is reported.
    const auto status = results->status();
    if (status != ND4J_STATUS_OK) {
        delete results;
        ASSERT_EQ(ND4J_STATUS_OK, status);
    }

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);    // retrieved but not compared in this test

    // Non-fatal checks: every mismatch is reported and control always reaches
    // the cleanup below.
    EXPECT_TRUE(dLdpExp.isSameShape(dLdp));
    EXPECT_TRUE(dLdpExp.equalsTo(dLdp));
    EXPECT_TRUE(dLdwExp.isSameShape(dLdw));
    EXPECT_TRUE(dLdwExp.equalsTo(dLdw));

    delete results;
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test7) {
    // Evaluates softmax_cross_entropy_loss_grad on {2,3,4} logits/labels with
    // explicit {1,3} weights (t-arg 0., i-arg 3) and compares outputs 0 and 1
    // with stored reference values.

    NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1, 1,0,0,0, 0,1,0,0}, sd::DataType::INT32);
    NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE);
    NDArray weights('c', {1,3}, {0.5, 0., 1.5});

    NDArray dLdpExp('c', {2,3,4}, {-0.0956 , 0.0306 , 0.03185, 0.03315, 0.,-0., 0., 0., 0.0882 , 0.0918 ,-0.27945, 0.09945,
                                   0.0294 , 0.0306 , 0.03185,-0.09185,-0., 0., 0., 0., 0.0882 ,-0.2832 , 0.09555, 0.09945});
    NDArray dLdwExp('c', {1,3}, {0.69365, 0.71365, 0.69365});

    logits.linspace(-0.08, 0.04);

    sd::ops::softmax_cross_entropy_loss_grad op;
    auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3});

    // A failed ASSERT_* returns from the test body immediately, so the result set
    // has to be released before a bad status is reported.
    const auto status = results->status();
    if (status != ND4J_STATUS_OK) {
        delete results;
        ASSERT_EQ(ND4J_STATUS_OK, status);
    }

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);    // retrieved but not compared in this test

    // Non-fatal checks: every mismatch is reported and control always reaches
    // the cleanup below.
    EXPECT_TRUE(dLdpExp.isSameShape(dLdp));
    EXPECT_TRUE(dLdpExp.equalsTo(dLdp));
    EXPECT_TRUE(dLdwExp.isSameShape(dLdw));
    EXPECT_TRUE(dLdwExp.equalsTo(dLdw));

    delete results;
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test8) {
    // Evaluates softmax_cross_entropy_loss_grad on {2,3,4,5} logits/labels with
    // {1,1,4} weights (t-arg 0., i-arg 2). Only output 1 is compared with its
    // reference values; the comparison of output 0 is disabled (see below).

    NDArray labels('c', {2,3,4,5}, {1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,
                                    0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,
                                    0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0}, sd::DataType::INT32);

    NDArray logits('c', {2,3,4,5}, sd::DataType::DOUBLE);
    NDArray weights('c', {1,1,4}, sd::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,3,4,5}, {-0.03399, 0.00799, 0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335,
                                     0.00866, 0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399,
                                     0.00799, 0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866,
                                     0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799,
                                     0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901,
                                     0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799, 0.00832,
                                     0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901, 0.00768,
                                     0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799, 0.00832, 0.00866,
                                     0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901});

    NDArray dLdwExp('c', {1,1,4}, {0.005, 0.00167, -0.00167, -0.005});

    logits.linspace(-0.08, 0.04);
    weights.assign(0.5);

    sd::ops::softmax_cross_entropy_loss_grad op;
    auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2});

    // A failed ASSERT_* returns from the test body immediately, so the result set
    // has to be released before a bad status is reported.
    const auto status = results->status();
    if (status != ND4J_STATUS_OK) {
        delete results;
        ASSERT_EQ(ND4J_STATUS_OK, status);
    }

    auto *dLdp = results->at(0);
    auto *dLdw = results->at(1);
    auto *dLdl = results->at(2);    // retrieved but not compared in this test

    // dLdp->printIndexedBuffer();

    // NOTE(review): the comparison of output 0 against dLdpExp is disabled; the
    // reason is not visible in this file - confirm before re-enabling.
    // ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
    // ASSERT_TRUE(dLdpExp.equalsTo(dLdp));

    // Non-fatal checks: every mismatch is reported and control always reaches
    // the cleanup below.
    EXPECT_TRUE(dLdwExp.isSameShape(dLdw));
    EXPECT_TRUE(dLdwExp.equalsTo(dLdw));

    delete results;
}
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, SafeDivideMixed_Test1) {
    // Smoke test without assertions: applies pairwise::SafeDivide with an INT64
    // right-hand operand, writing the result back into the left-hand array.

    NDArray source('c', {2, 3}, {1.0, 2.0, 3.0, -1.0, 2.0, 1.0});

    // Sum-reduce along dimension 1; the reduced array doubles as the output below.
    auto reduced = source.reduceAlongDimension(reduce::Sum, {1}, true);

    // INT64 array built from the reduced array's shape info, filled with ones.
    NDArray divisor(reduced.getShapeInfo(), sd::DataType::INT64, false);
    divisor.assign(1);

    reduced.applyPairwiseTransform(pairwise::SafeDivide, divisor, reduced);
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test1) {
    // Evaluates softmax_cross_entropy_loss_with_logits_grad on {2,3,4} inputs with
    // no t-args and no i-args, and compares output 0 with stored reference values.

    NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1, 1,0,0,0, 0,1,0,0});
    NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,3,4}, {-0.76479, 0.2448, 0.2548, 0.26519, 0.23521,-0.7552, 0.2548, 0.26519, 0.23521, 0.2448,-0.7452, 0.26519,
                                   0.23521, 0.2448, 0.2548,-0.73481,-0.76479, 0.2448, 0.2548, 0.26519, 0.23521,-0.7552, 0.2548, 0.26519});

    logits.linspace(-0.08, 0.04);

    sd::ops::softmax_cross_entropy_loss_with_logits_grad op;
    auto results = op.evaluate({&logits, &labels}, {}, {});

    // A failed ASSERT_* returns from the test body immediately, so the result set
    // has to be released before a bad status is reported.
    const auto status = results->status();
    if (status != ND4J_STATUS_OK) {
        delete results;
        ASSERT_EQ(ND4J_STATUS_OK, status);
    }

    auto *dLdp = results->at(0);

    // Non-fatal checks: every mismatch is reported and control always reaches
    // the cleanup below.
    EXPECT_TRUE(dLdpExp.isSameShape(dLdp));
    EXPECT_TRUE(dLdpExp.equalsTo(dLdp));

    delete results;
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test2) {
    // Evaluates softmax_cross_entropy_loss_with_logits_grad on {2,3,4} inputs with
    // i-arg 1 and compares output 0 with stored reference values.

    NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,1, 0,0,1,0, 0,0,0,1, 1,0,1,0, 0,1,0,0});
    NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,3,4}, {-0.71836, 0.28164, 0.28164, 0.28164, 0.33051, -0.66949, 0.33051, -0.66949, 0.38785, 0.38785, -0.61215, 0.38785,
                                   0.28164, 0.28164, 0.28164, -0.71836,-0.66949, 0.33051, -0.66949, 0.33051, 0.38785, -0.61215, 0.38785, 0.38785});

    logits.linspace(-0.08, 0.04);

    sd::ops::softmax_cross_entropy_loss_with_logits_grad op;
    auto results = op.evaluate({&logits, &labels}, {}, {1});

    // A failed ASSERT_* returns from the test body immediately, so the result set
    // has to be released before a bad status is reported.
    const auto status = results->status();
    if (status != ND4J_STATUS_OK) {
        delete results;
        ASSERT_EQ(ND4J_STATUS_OK, status);
    }

    auto *dLdp = results->at(0);

    // Non-fatal checks: every mismatch is reported and control always reaches
    // the cleanup below.
    EXPECT_TRUE(dLdpExp.isSameShape(dLdp));
    EXPECT_TRUE(dLdpExp.equalsTo(dLdp));

    delete results;
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test3) {
    // Evaluates softmax_cross_entropy_loss_with_logits_grad on {2,3} inputs with
    // i-arg 0 and compares output 0 with stored reference values.

    NDArray labels('c', {2,3}, {1,0,0, 0,1,1});
    NDArray logits('c', {2,3}, sd::DataType::DOUBLE);

    NDArray dLdpExp('c', {2,3}, {-0.52996, 0.47004, 0.47004, 0.52996, -0.47004, -0.47004});

    logits.linspace(-0.08, 0.04);

    sd::ops::softmax_cross_entropy_loss_with_logits_grad op;
    auto results = op.evaluate({&logits, &labels}, {}, {0});

    // A failed ASSERT_* returns from the test body immediately, so the result set
    // has to be released before a bad status is reported.
    const auto status = results->status();
    if (status != ND4J_STATUS_OK) {
        delete results;
        ASSERT_EQ(ND4J_STATUS_OK, status);
    }

    auto *dLdp = results->at(0);

    // Non-fatal checks: every mismatch is reported and control always reaches
    // the cleanup below.
    EXPECT_TRUE(dLdpExp.isSameShape(dLdp));
    EXPECT_TRUE(dLdpExp.equalsTo(dLdp));

    delete results;
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test4) {
    // Evaluates softmax_cross_entropy_loss_with_logits_grad on {2,1} inputs with
    // i-arg 1; the stored reference for output 0 is all zeros.

    NDArray labels('c', {2,1}, {1,1});
    NDArray logits('c', {2,1}, {-0.04, 0.04});

    NDArray dLdpExp('c', {2,1}, {0., 0.});

    sd::ops::softmax_cross_entropy_loss_with_logits_grad op;
    auto results = op.evaluate({&logits, &labels}, {}, {1});

    // A failed ASSERT_* returns from the test body immediately, so the result set
    // has to be released before a bad status is reported.
    const auto status = results->status();
    if (status != ND4J_STATUS_OK) {
        delete results;
        ASSERT_EQ(ND4J_STATUS_OK, status);
    }

    auto *dLdp = results->at(0);

    // Non-fatal checks: every mismatch is reported and control always reaches
    // the cleanup below.
    EXPECT_TRUE(dLdpExp.isSameShape(dLdp));
    EXPECT_TRUE(dLdpExp.equalsTo(dLdp));

    delete results;
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test5) {
|
|
|
|
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray labels('c', {2,1}, std::vector<double>{1,0});
|
2019-06-06 14:21:15 +02:00
|
|
|
NDArray logits('c', {2,1}, {-0.04, 0.04});
|
2019-09-11 19:12:09 +02:00
|
|
|
|
|
|
|
NDArray dLdpExp('c', {2,1}, {-0.51999, 0.51999});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-01-30 08:07:24 +01:00
|
|
|
auto results = op.evaluate({&logits, &labels}, {}, {0});
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
|
2019-09-11 19:12:09 +02:00
|
|
|
auto *dLdp = results->at(0);
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
|
|
|
|
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
|
|
|
|
|
|
|
|
delete results;
|
|
|
|
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test6) {
|
|
|
|
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray labels('c', {1,2}, {1,1.});
|
2019-06-06 14:21:15 +02:00
|
|
|
NDArray logits('c', {1,2}, {-0.04, 0.04});
|
2019-09-11 19:12:09 +02:00
|
|
|
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray dLdpExp('c', {1,2}, {0, 0.});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-01-30 08:07:24 +01:00
|
|
|
auto results = op.evaluate({&logits, &labels}, {}, {0});
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
|
2019-09-11 19:12:09 +02:00
|
|
|
auto *dLdp = results->at(0);
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
|
|
|
|
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
|
|
|
|
|
|
|
|
delete results;
|
|
|
|
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test7) {

    // Rank-1 inputs: two logits, label mass entirely on the second entry.
    NDArray logits('c', {2}, {-0.04, 0.04});
    NDArray labels('c', {2}, {0,1});

    // Reference output; the values agree with softmax(logits) - labels to 5 decimals.
    NDArray expectedGrad('c', {2}, {0.48001, -0.48001});

    // The single integer argument {0} is presumably the class dimension -- TODO confirm against the op.
    sd::ops::softmax_cross_entropy_loss_with_logits_grad gradOp;
    auto outcome = gradOp.evaluate({&logits, &labels}, {}, {0});

    // Fatal assertion: nothing below runs (including the delete) if the op failed.
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actualGrad = outcome->at(0);

    ASSERT_TRUE(expectedGrad.isSameShape(actualGrad));
    ASSERT_TRUE(expectedGrad.equalsTo(actualGrad));

    // The result set is heap-allocated and released manually here.
    delete outcome;
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test8) {
|
|
|
|
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray labels('c', {1}, std::vector<double>{1});
|
|
|
|
NDArray logits('c', {1}, std::vector<double>{0.04});
|
2019-09-11 19:12:09 +02:00
|
|
|
|
Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor more corrections to support convertors
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading
* StringUtils for utf convertor #8613 tests added some corrections to optimize build
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some corrections and code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion
* StringUtils for utf convertor #8613 some fixes and tets
* StringUtils for utf convertor #8613 some more staff to support different unicode
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fix linking bug
* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed
* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing
* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)
* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master and sync with server
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up
* StringUtils for utf convertor #8613 fixed cast and copy issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda and update tests
* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc
* - avoid ambiguity of NDArray ctrs overloading in some tests
Signed-off-by: Yurii <iuriish@yahoo.com>
* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 several more fixes, improvements and updates
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 minor fixes string element size define
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 revert last changes as mistake
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
|
|
|
NDArray dLdpExp('c', {1}, std::vector<double>{0});
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-01-30 08:07:24 +01:00
|
|
|
auto results = op.evaluate({&logits, &labels}, {}, {0});
|
2019-09-11 19:12:09 +02:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
|
2019-09-11 19:12:09 +02:00
|
|
|
auto *dLdp = results->at(0);
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
|
|
|
|
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
|
|
|
|
|
|
|
|
delete results;
|
|
|
|
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, Multiply_BP_Test1) {

    // Left operand is {3,4,5}; right operand is a {1,1,1} array.
    NDArray x('c', {3,4,5}, sd::DataType::DOUBLE);
    x.assign(1.0);//linspace(0.1, 0.1);

    NDArray y('c', {1,1,1}, sd::DataType::DOUBLE);
    y.assign(1.0);

    // Incoming gradient, same shape as x, all ones.
    NDArray dLdp('c', {3,4,5}, sd::DataType::DOUBLE);
    dLdp.assign(1.0);

    // Expected first output: all ones with the shape of x.
    NDArray expectedGrad('c', {3,4,5}, sd::DataType::DOUBLE);
    expectedGrad.assign(1.0);

    sd::ops::multiply_bp backpropOp;
    auto outcome = backpropOp.evaluate({&x, &y, &dLdp}, {}, {});

    // Fatal assertion: nothing below runs (including the delete) if the op failed.
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    // Only output 0 is verified (presumably the gradient w.r.t. x -- TODO confirm);
    // any further outputs are not inspected by this test.
    auto *firstGrad = outcome->at(0);

    ASSERT_TRUE(expectedGrad.isSameShape(firstGrad));
    ASSERT_TRUE(expectedGrad.equalsTo(firstGrad));

    // The result set is heap-allocated and released manually here.
    delete outcome;
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test1) {

    // Two samples over three classes; labels are INT64 class indices.
    NDArray logits('c', {2,3}, sd::DataType::DOUBLE);
    logits.linspace(0.1, 0.1);

    NDArray labels('c', {2}, {2,1}, sd::DataType::INT64);

    // Reference output: each row's softmax with 1 subtracted at the labelled index
    // (indices 2 and 1 respectively), rounded to 5 decimals.
    NDArray expectedGrad('c', {2,3}, {0.30061, 0.33222, -0.63283, 0.30061, -0.66778, 0.36717});

    // Note the input order for the sparse op: labels first, then logits.
    sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad gradOp;
    auto outcome = gradOp.evaluate({&labels, &logits}, {}, {});

    // Fatal assertion: nothing below runs (including the delete) if the op failed.
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actualGrad = outcome->at(0);

    ASSERT_TRUE(expectedGrad.isSameShape(actualGrad));
    ASSERT_TRUE(expectedGrad.equalsTo(actualGrad));

    // The result set is heap-allocated and released manually here.
    delete outcome;
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test2) {

    // Same layout as the previous sparse case, but the logits fill starts at -0.1
    // and the labelled classes are 0 and 1.
    NDArray logits('c', {2,3}, sd::DataType::DOUBLE);
    logits.linspace(-0.1, 0.1);

    NDArray labels('c', {2}, {0,1}, sd::DataType::INT64);

    // Reference output: each row's softmax with 1 subtracted at the labelled index,
    // rounded to 5 decimals.
    NDArray expectedGrad('c', {2,3}, {-0.69939, 0.33222, 0.36717, 0.30061, -0.66778, 0.36717});

    // Note the input order for the sparse op: labels first, then logits.
    sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad gradOp;
    auto outcome = gradOp.evaluate({&labels, &logits}, {}, {});

    // Fatal assertion: nothing below runs (including the delete) if the op failed.
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actualGrad = outcome->at(0);

    ASSERT_TRUE(expectedGrad.isSameShape(actualGrad));
    ASSERT_TRUE(expectedGrad.equalsTo(actualGrad));

    // The result set is heap-allocated and released manually here.
    delete outcome;
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) {

    // A single sample with two classes.
    NDArray logits('c', {2}, {-0.2, 0.3});

    // Label built with an empty shape list (presumably a scalar -- TODO confirm)
    // holding class index 1, stored as INT64.
    NDArray labels('c', {}, std::vector<double>{1}, sd::DataType::INT64);

    // Reference output: softmax(logits) with 1 subtracted at index 1, to 5 decimals.
    NDArray expectedGrad('c', {2}, {0.37754, -0.37754});

    // Note the input order for the sparse op: labels first, then logits.
    sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad gradOp;
    auto outcome = gradOp.evaluate({&labels, &logits}, {}, {});

    // Fatal assertion: nothing below runs (including the delete) if the op failed.
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actualGrad = outcome->at(0);

    ASSERT_TRUE(expectedGrad.isSameShape(actualGrad));
    ASSERT_TRUE(expectedGrad.equalsTo(actualGrad));

    // The result set is heap-allocated and released manually here.
    delete outcome;
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test4) {

    // Rank-3 logits {2,3,4} with a rank-2 INT64 label array {2,3}: one class index
    // per group of four logits.
    NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE);
    logits.linspace(-0.5, 0.1);

    NDArray labels('c', {2,3}, {0,1,1, 3,3,2}, sd::DataType::INT64);

    // Reference output: every group of four holds the same softmax values, with 1
    // subtracted at the labelled index (rounded to 5 decimals).
    NDArray expectedGrad('c', {2,3,4}, {-0.78616, 0.23633, 0.26118, 0.28865, 0.21384, -0.76367, 0.26118, 0.28865, 0.21384, -0.76367, 0.26118, 0.28865,
                                        0.21384, 0.23633, 0.26118, -0.71135, 0.21384, 0.23633, 0.26118, -0.71135, 0.21384, 0.23633, -0.73882, 0.28865});

    // Note the input order for the sparse op: labels first, then logits.
    sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad gradOp;
    auto outcome = gradOp.evaluate({&labels, &logits}, {}, {});

    // Fatal assertion: nothing below runs (including the delete) if the op failed.
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actualGrad = outcome->at(0);

    ASSERT_TRUE(expectedGrad.isSameShape(actualGrad));
    ASSERT_TRUE(expectedGrad.equalsTo(actualGrad));

    // The result set is heap-allocated and released manually here.
    delete outcome;
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////
|
|
|
|
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test5) {

    // Singleton leading dimensions: logits are {1,1,2}, the INT64 label array is {1,1}.
    NDArray logits('c', {1,1,2}, {-0.3,0.2});
    NDArray labels('c', {1,1}, std::vector<double>({0}), sd::DataType::INT64);

    // Reference output: softmax(logits) with 1 subtracted at index 0, to 5 decimals.
    NDArray expectedGrad('c', {1,1,2}, {-0.62246, 0.62246});

    // Note the input order for the sparse op: labels first, then logits.
    sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad gradOp;
    auto outcome = gradOp.evaluate({&labels, &logits}, {}, {});

    // Fatal assertion: nothing below runs (including the delete) if the op failed.
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actualGrad = outcome->at(0);

    ASSERT_TRUE(expectedGrad.isSameShape(actualGrad));
    ASSERT_TRUE(expectedGrad.equalsTo(actualGrad));

    // The result set is heap-allocated and released manually here.
    delete outcome;
}
|
|
|
|
|
|
|
|
|
|
|
|
|