cavis/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp

3713 lines
185 KiB
C++
Raw Normal View History

2019-06-06 14:21:15 +02:00
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver on 8/4/2018.
//
#include "testlayers.h"
#include <ops/declarable/CustomOperations.h>
#include <NDArray.h>
#include <ops/ops.h>
#include <GradCheck.h>
using namespace nd4j;
class DeclarableOpsTests11 : public testing::Test {
public:
DeclarableOpsTests11() {
printf("\n");
fflush(stdout);
}
};
TEST_F(DeclarableOpsTests11, test_listdiff_1) {
auto x = NDArrayFactory::create<int>('c', {4}, {0, 1, 2, 3});
auto y = NDArrayFactory::create<int>('c',{2}, {3, 1});
nd4j::ops::listdiff op;
auto result = op.evaluate({&x, &y}, {}, {});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(Status::OK(), result->status());
delete result;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test1) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692,
-24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911});
NDArray dLdwExp('c', {2,3,4}, {3.21887, 4.96807, 6.10512, 6.80726, 7.15461, 7.19051, 6.93973, 6.41584, 5.62456, 4.56548, 3.2326 , 1.61444,
-0.30659, -2.55529, -5.16569, -8.18417,-11.67468,-15.72734,-20.47379,-26.11644,-32.9902 ,-41.71318,-53.64824,-73.05434});
NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002,
-0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::log_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}, {});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test2) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,1,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdwExp('c', {2,1,4}, {15.99805, 16.72406, 16.27746, 14.83754,-44.97147,-59.99582,-79.28771,-107.35497});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::log_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdw = results->at(1);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test3) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights(nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692,
-24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911});
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray dLdwExp('c', {}, std::vector<double>{-227.77286});
2019-06-06 14:21:15 +02:00
NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002,
-0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::log_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test4) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdwExp('c', {1,3,1}, {4.8876 , -46.29156, -186.36887});
2019-06-06 14:21:15 +02:00
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::log_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
// dLdw->printIndexedBuffer();
// dLdw->printShapeInfo();
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test5) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-1.04166,-1.08696, -1.13636, -1.19048,-1.25 ,-1.31579, -1.38889, -1.47059,-1.5625 ,-1.66667, -1.78571, -1.92308,
-2.08333,-2.27273, -2.5 , -2.77778,-3.125 ,-3.57143, -4.16667, -5. ,-6.25 ,-8.33333,-12.49999,-24.99993});
NDArray dLdwExp('c', {2,3,4}, {1.05912, 1.20488, 1.29964, 1.35815, 1.3871 , 1.39009, 1.36919, 1.32553, 1.25959, 1.17133, 1.06026, 0.92541,
0.76533, 0.57794, 0.3604 , 0.10886,-0.18201,-0.51973,-0.91527,-1.38549,-1.95831,-2.68522,-3.67981,-5.29698});
NDArray dLdlExp('c', {2,3,4}, {0.13242, 0.10176, 0.08302, 0.06909, 0.05776, 0.04803, 0.03935, 0.03141, 0.02397, 0.01689, 0.01005, 0.00334,
-0.00334,-0.01005,-0.01689,-0.02397,-0.03141,-0.03935,-0.04803,-0.05776,-0.06909,-0.08302,-0.10176,-0.13242});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::log_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test6) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdwExp('c', {1,3,1}, {6.73432, 2.46939,-9.20372});
2019-06-06 14:21:15 +02:00
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::log_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdw = results->at(1);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test7) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights(nd4j::DataType::DOUBLE);
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray dLdwExp('c', {}, std::vector<double>{0.});
2019-06-06 14:21:15 +02:00
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::log_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdw = results->at(1);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
2019-06-06 14:21:15 +02:00
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test8) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. ,-1.5 ,-1.57895, -1.66667, -1.76471,-1.875 ,-2. , -2.14286, -2.30769,
-2.5 ,-2.72727, -3. , -3.33333,-3.75 ,-4.28571, -5. , -6. ,-7.49999,-9.99999,-14.99999,-29.99991});
NDArray dLdwExp('c', {2,3,4}, {1.56625, 1.74117, 1.85487, 1.92509, 1.95982, 1.96341, 1.93833, 1.88594, 1.80682, 1.70091, 1.56762, 1.4058 ,
1.2137 , 0.98883, 0.72779, 0.42594, 0.07689,-0.32837,-0.80302,-1.36728,-2.05466,-2.92696,-4.12046,-6.06107});
NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0.06931, 0.05763, 0.04722, 0.03769, 0.02877, 0.02027, 0.01206, 0.004,
-0.004 ,-0.01206,-0.02027,-0.02877,-0.03769,-0.04722,-0.05763,-0.06931,-0.08291,-0.09962,-0.12212,-0.1589});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
weights.p(0, 0.);
weights.p(1, 0.);
weights.p(2, 0.);
weights.p(3, 0.);
nd4j::ops::log_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test9) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-0.52083,-0.54348,-0.56818, -0.59524,-0.625 ,-0.65789,-0.69444, -0.73529,-0.78125,-0.83333,-0.89286, -0.96154,
-1.04167,-1.13636,-1.25 , -1.38889,-1.5625 ,-1.78571,-2.08333, -2.5 ,-3.125 ,-4.16666,-6.24999,-12.49996});
NDArray dLdwExp('c', {2,3,4}, {0.13412, 0.207 , 0.25438, 0.28364, 0.29811, 0.2996 , 0.28916, 0.26733, 0.23436, 0.19023, 0.13469, 0.06727,
-0.01277,-0.10647,-0.21524,-0.34101,-0.48645,-0.65531,-0.85307,-1.08819,-1.37459,-1.73805,-2.23534,-3.04393});
NDArray dLdlExp('c', {2,3,4}, {0.06621, 0.05088, 0.04151, 0.03455, 0.02888, 0.02401, 0.01968, 0.0157 , 0.01199, 0.00845, 0.00502, 0.00167,
-0.00167,-0.00502,-0.00845,-0.01199,-0.0157 ,-0.01968,-0.02401,-0.02888,-0.03455,-0.04151,-0.05088,-0.06621});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::log_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
delete results;
}
2019-06-06 14:21:15 +02:00
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test10) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray dLdwExp('c', {1,1}, std::vector<double>{-9.49054});
2019-06-06 14:21:15 +02:00
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
nd4j::ops::log_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdw = results->at(1);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test11) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
NDArray dLdwExp('c', {1,3,1}, {0.20365,-1.92882,-7.76537});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
nd4j::ops::log_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdw = results->at(1);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test12) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, { 0. , 0. , 0. , 0. ,-0.75 ,-0.789473,-0.833333, -0.882353,-0.9375 ,-1. ,-1.071428, -1.153846,
-1.25 ,-1.363636,-1.5 , -1.666666,-1.875 ,-2.142857,-2.499999, -2.999999,-3.749997,-4.999997,-7.499993,-14.999956});
NDArray dLdwExp('c', {2,3,4}, {0.16094, 0.2484 , 0.30526, 0.34036, 0.35773, 0.35953, 0.34699, 0.32079, 0.28123, 0.22827, 0.16163, 0.08072,
-0.01533,-0.12776,-0.25828,-0.40921,-0.58373,-0.78637,-1.02369,-1.30582,-1.64951,-2.08566,-2.68241,-3.65272});
NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0.03466, 0.02882, 0.02361, 0.01884, 0.01438, 0.01014, 0.00603, 0.002 ,
-0.002 ,-0.00603,-0.01014,-0.01438,-0.01884,-0.02361,-0.02882,-0.03466,-0.04146,-0.04981,-0.06106,-0.07945});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
weights.t<double>(0) = 0.;
weights.t<double>(1) = 0.;
weights.t<double>(2) = 0.;
weights.t<double>(3) = 0.;
2019-06-06 14:21:15 +02:00
nd4j::ops::log_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, log_loss_grad_test13) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,1}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
-2.08333,-2.27273, -2.5 , -2.77778,-3.125 ,-3.57143, -4.16667, -5. ,-6.25 ,-8.33333,-12.49999,-24.99993});
NDArray dLdwExp('c', {2,3,1}, {1.75828, 2.30839, 1.25309, -1.35098, -6.16602,-16.78383});
NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
-0.00334,-0.01005,-0.01689,-0.02397,-0.03141,-0.03935,-0.04803,-0.05776,-0.06909,-0.08302,-0.10176,-0.13242});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
weights.t<double>(0) = 0.;
weights.t<double>(1) = 0.;
weights.t<double>(2) = 0.;
2019-06-06 14:21:15 +02:00
nd4j::ops::log_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
delete results;
}
Shugeo image resize bicubic (#56) * Added implementation files for image_resize and resize_bicubic ops. * Image resize and image.resize_bicubic ops implementation. Initial revision. * Finished with infrastructure development for image.resize_bilinear op and image_resizo op implementation. * Refactored resize methods. * Added processing for Mitchelcubic algorithm. * Added check for input/output sizes. * Added int and float types for crop_and_resize op. * Refactored crop_and_resize output type check. * Added helper for bicubic interpolation as TF v.1 does. * Added TF v.1 bicubic helper for cuda platform. * Added cached class for bicubic algorithm. * Refactored cuda implementation for crop_and_resize helper to use proper output type. * Added facilities for bicubic interpolation. * Portion bicubic interpolation from TF. * Added tests for resize_bilinear testing. * Working implementation of bicubic interpolation and tests. * Refactored routines with image_resize bicubic op helper. * Refactored code with coding standards. * Refactored cpu helpers for resize_bicubic op. * Refactored bicubic helpers. * Added bicubic resize facilities. * Implementing cuda kernels for bicubic interpolation. Implementation step. * Cuda implementation of resize_bicubic op helper. * Refactor image.resize_bicubic op helpers. * Refactored helpers for resize_bicubic. Added error checking with cuda implementation. * Refactored cuda implementation of resize_bicubic op helper. The first working revision. * Cuda arch implementation for resize_bicubic op helper. Full working single-threaded revision. * Intermediate bicubic interpolation helper for cuda. * Refactored cpu helper for resize_bicubic. * Multithreaded cuda implementation for resize_bicubic. * Fixed merge issues. * Refactored nlp helpers. * Replicated resize_bicubic for 3D also. * Eliminated waste comments of unused code. * Eliminated waste comments with unused code. * Eliminated waste template definitions. * Eliminated waste debug code. * Eliminated waste comments. * Fixed multithreading with helpers. * Fixed test suites for float and double in float point input lists. * Fixed usage of reshape with 3D/4D on resizes. * Final fixes. * Fixed resize_neighbor op problem.
2019-11-20 20:11:04 +01:00
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test1) {
NDArray input = NDArrayFactory::create<float>('c', {1, 7, 7, 1}, {
1.f, 2.1f, 3.15f, 4.2f, 5.15f, 6.1f, 7.f,
8.f, 9.1f, 10.f, 11.f, 12.9f, 13.1f, 14.f,
15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f,
22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f,
30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f,
37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f,
44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f
Shugeo image resize bicubic (#56) * Added implementation files for image_resize and resize_bicubic ops. * Image resize and image.resize_bicubic ops implementation. Initial revision. * Finished with infrastructure development for image.resize_bilinear op and image_resizo op implementation. * Refactored resize methods. * Added processing for Mitchelcubic algorithm. * Added check for input/output sizes. * Added int and float types for crop_and_resize op. * Refactored crop_and_resize output type check. * Added helper for bicubic interpolation as TF v.1 does. * Added TF v.1 bicubic helper for cuda platform. * Added cached class for bicubic algorithm. * Refactored cuda implementation for crop_and_resize helper to use proper output type. * Added facilities for bicubic interpolation. * Portion bicubic interpolation from TF. * Added tests for resize_bilinear testing. * Working implementation of bicubic interpolation and tests. * Refactored routines with image_resize bicubic op helper. * Refactored code with coding standards. * Refactored cpu helpers for resize_bicubic op. * Refactored bicubic helpers. * Added bicubic resize facilities. * Implementing cuda kernels for bicubic interpolation. Implementation step. * Cuda implementation of resize_bicubic op helper. * Refactor image.resize_bicubic op helpers. * Refactored helpers for resize_bicubic. Added error checking with cuda implementation. * Refactored cuda implementation of resize_bicubic op helper. The first working revision. * Cuda arch implementation for resize_bicubic op helper. Full working single-threaded revision. * Intermediate bicubic interpolation helper for cuda. * Refactored cpu helper for resize_bicubic. * Multithreaded cuda implementation for resize_bicubic. * Fixed merge issues. * Refactored nlp helpers. * Replicated resize_bicubic for 3D also. * Eliminated waste comments of unused code. * Eliminated waste comments with unused code. * Eliminated waste template definitions. * Eliminated waste debug code. * Eliminated waste comments. * Fixed multithreading with helpers. * Fixed test suites for float and double in float point input lists. * Fixed usage of reshape with 3D/4D on resizes. * Final fixes. * Fixed resize_neighbor op problem.
2019-11-20 20:11:04 +01:00
});
NDArray expected = NDArrayFactory::create<float>('c', {1, 30, 30, 1}, {
1.f, 1.1976162f, 1.4174359f, 1.6775769f, 1.9961575f, 2.3283265f,
2.550918f, 2.7360606f, 2.9655411f, 3.2929654f, 3.5441515f, 3.7380352f,
3.948995f, 4.248106f, 4.5073795f, 4.6843743f, 4.8572845f, 5.104302f,
5.3869915f, 5.581401f, 5.7539616f, 5.974285f, 6.272836f, 6.5204263f,
6.718899f, 6.8871036f, 7.039068f, 7.099216f, 7.0784245f, 7.0281887f,
2.247592f, 2.446947f, 2.6694887f, 2.9312382f, 3.248216f, 3.5745337f,
3.78931f, 3.9656973f, 4.186417f, 4.5046535f, 4.740569f, 4.9217057f,
5.133866f, 5.459533f, 5.7744613f, 6.0197873f, 6.254011f, 6.535633f,
6.8097296f, 6.9607787f, 7.0749416f, 7.241601f, 7.5094895f, 7.7499495f,
7.954571f, 8.131972f, 8.286526f, 8.346463f, 8.325745f, 8.275683f,
3.6286845f, 3.830573f, 4.0569587f, 4.3211575f, 4.6364856f, 4.9556503f,
5.160583f, 5.3258467f, 5.535462f, 5.84216f, 6.058749f, 6.223753f,
6.437597f, 6.797369f, 7.1836042f, 7.5164022f, 7.8290343f, 8.154773f,
8.417635f, 8.512958f, 8.5521f, 8.649708f, 8.87788f, 9.108794f,
9.320926f, 9.509781f, 9.667375f, 9.72694f, 9.706349f, 9.656599f,
5.276778f, 5.480438f, 5.709702f, 5.9754477f, 6.288551f, 6.6005697f,
6.796207f, 6.9511423f, 7.1503997f, 7.4461427f, 7.644651f, 7.794562f,
8.009684f, 8.400473f, 8.851847f, 9.26469f, 9.649218f, 10.015648f,
10.268647f, 10.313368f, 10.2843275f, 10.319379f, 10.512033f, 10.734956f,
10.954604f, 11.154507f, 11.315369f, 11.374779f, 11.354242f, 11.304622f,
7.325373f, 7.5284843f, 7.757575f, 8.022221f, 8.331997f, 8.638187f,
8.827649f, 8.976217f, 9.168955f, 9.45726f, 9.6442375f, 9.784517f,
9.999621f, 10.407702f, 10.896234f, 11.355122f, 11.781423f, 12.172186f,
12.420712f, 12.4374485f, 12.370511f, 12.371386f, 12.545973f, 12.766424f,
12.992249f, 13.20012f, 13.364252f, 13.424109f, 13.40342f, 13.353425f,
9.493208f, 9.692467f, 9.9169445f, 10.176801f, 10.482199f, 10.78547f,
10.974367f, 11.123442f, 11.31637f, 11.603645f, 11.790616f, 11.930889f,
12.144082f, 12.546447f, 13.024898f, 13.4723f, 13.889232f, 14.276275f,
14.528972f, 14.555555f, 14.50145f, 14.515459f, 14.700572f, 14.927055f,
15.156046f, 15.366046f, 15.532901f, 15.594008f, 15.5728855f, 15.521847f,
10.970133f, 11.163599f, 11.380694f, 11.633735f, 11.935032f, 12.238887f,
12.43254f, 12.588294f, 12.787534f, 13.079956f, 13.27752f, 13.426631f,
13.636713f, 14.013844f, 14.441672f, 14.827978f, 15.191209f, 15.549808f,
15.81343f, 15.881828f, 15.883522f, 15.950411f, 16.16933f, 16.40794f,
16.636436f, 16.842583f, 17.010887f, 17.07363f, 17.05194f, 16.999537f,
12.219155f, 12.406129f, 12.614796f, 12.860335f, 13.157928f, 13.464224f,
13.665207f, 13.830567f, 14.039036f, 14.339629f, 14.552863f, 14.715049f,
14.921564f, 15.264454f, 15.622843f, 15.924977f, 16.213829f, 16.532364f,
16.8099f, 16.934835f, 17.012146f, 17.150164f, 17.413412f, 17.666712f,
17.892765f, 18.09207f, 18.261044f, 18.325531f, 18.303238f, 18.249378f,
13.7663965f, 13.947391f, 14.148263f, 14.386917f, 14.681246f, 14.990087f,
15.198166f, 15.372728f, 15.590062f, 15.898583f, 16.126892f, 16.301655f,
16.50487f, 16.815214f, 17.107498f, 17.329458f, 17.547403f, 17.827654f,
18.118288f, 18.296928f, 18.4461f, 18.651634f, 18.956806f, 19.22382f,
19.447308f, 19.639887f, 19.809319f, 19.875397f, 19.852556f, 19.797365f,
15.9419365f, 16.118704f, 16.314133f, 16.547867f, 16.839561f, 17.14954f,
17.361883f, 17.542162f, 17.764957f, 18.078188f, 18.315733f, 18.498205f,
18.699116f, 18.988684f, 19.238989f, 19.410137f, 19.583265f, 19.839512f,
20.13878f, 20.35177f, 20.546844f, 20.795671f, 21.128067f, 21.404358f,
21.626736f, 21.8155f, 21.98561f, 22.052843f, 22.029604f, 21.973448f,
17.53522f, 17.71077f, 17.904636f, 18.13695f, 18.42784f, 18.738056f,
18.951529f, 19.133352f, 19.357613f, 19.672083f, 19.912102f, 20.096638f,
20.296894f, 20.580765f, 20.819603f, 20.976887f, 21.137802f, 21.387535f,
21.689209f, 21.911621f, 22.119276f, 22.37999f, 22.71991f, 22.998823f,
23.22097f, 23.40876f, 23.57911f, 23.646685f, 23.623325f, 23.566887f,
18.746353f, 18.922657f, 19.117487f, 19.350685f, 19.64207f, 19.952137f,
20.164913f, 20.345781f, 20.569134f, 20.88284f, 21.12133f, 21.30459f,
21.505253f, 21.792645f, 22.038572f, 22.204426f, 22.37289f, 22.626648f,
22.926834f, 23.143423f, 23.343302f, 23.596668f, 23.931936f, 24.209232f,
24.431519f, 24.619913f, 24.79011f, 24.857473f, 24.83419f, 24.777927f,
20.16656f, 20.344206f, 20.540766f, 20.775532f, 21.067804f, 21.377607f,
21.589132f, 21.768297f, 21.99003f, 22.302366f, 22.538124f, 22.719105f,
22.920494f, 23.214176f, 23.472767f, 23.653934f, 23.83589f, 24.096842f,
24.394371f, 24.600555f, 24.786541f, 25.026773f, 25.353731f, 25.62813f,
25.850672f, 26.04014f, 26.210072f, 26.277063f, 26.253906f, 26.197956f,
22.363024f, 22.54125f, 22.738552f, 22.973991f, 23.266647f, 23.57634f,
23.787327f, 23.96576f, 24.186796f, 24.498543f, 24.733124f, 24.913122f,
25.114826f, 25.411213f, 25.675262f, 25.863028f, 26.050789f, 26.314838f,
26.611223f, 26.812925f, 26.992926f, 27.227505f, 27.550882f, 27.824034f,
28.046684f, 28.236614f, 28.406433f, 28.473265f, 28.450163f, 28.394344f,
24.429443f, 24.60767f, 24.80497f, 25.04041f, 25.333065f, 25.642756f,
25.853743f, 26.032173f, 26.25321f, 26.564959f, 26.79954f, 26.97954f,
27.181242f, 27.47763f, 27.74168f, 27.929441f, 28.117207f, 28.381254f,
28.677637f, 28.879343f, 29.059345f, 29.293922f, 29.617298f, 29.890451f,
30.113104f, 30.303034f, 30.472853f, 30.539684f, 30.516582f, 30.460762f,
26.f, 26.178228f, 26.375526f, 26.61097f, 26.903624f, 27.213314f,
27.424305f, 27.602734f, 27.823772f, 28.135519f, 28.3701f, 28.550098f,
28.7518f, 29.04819f, 29.312237f, 29.5f, 29.687763f, 29.951813f,
30.2482f, 30.449903f, 30.629902f, 30.864483f, 31.187859f, 31.461012f,
31.683659f, 31.873592f, 32.043407f, 32.11024f, 32.087135f, 32.03132f,
27.570559f, 27.748787f, 27.946087f, 28.181528f, 28.474184f, 28.783876f,
28.994865f, 29.173294f, 29.39433f, 29.70608f, 29.940659f, 30.120655f,
30.32236f, 30.618746f, 30.882797f, 31.070557f, 31.25832f, 31.522371f,
31.818754f, 32.02046f, 32.20046f, 32.43504f, 32.758415f, 33.031567f,
33.25422f, 33.44415f, 33.613964f, 33.680794f, 33.657696f, 33.60188f,
29.636976f, 29.815207f, 30.0125f, 30.247944f, 30.5406f, 30.85029f,
31.061283f, 31.239712f, 31.46075f, 31.7725f, 32.00708f, 32.187077f,
32.38878f, 32.685165f, 32.949215f, 33.13698f, 33.32474f, 33.58879f,
33.885178f, 34.086884f, 34.26688f, 34.501457f, 34.824837f, 35.09799f,
35.320637f, 35.510574f, 35.68039f, 35.747215f, 35.724117f, 35.6683f,
31.83344f, 32.011665f, 32.20897f, 32.444412f, 32.73707f, 33.046757f,
33.257744f, 33.436176f, 33.657207f, 33.96896f, 34.203537f, 34.383537f,
34.58524f, 34.88163f, 35.145676f, 35.33344f, 35.521206f, 35.785255f,
36.081642f, 36.28334f, 36.46334f, 36.69792f, 37.021297f, 37.294453f,
37.517097f, 37.707027f, 37.876846f, 37.94368f, 37.920578f, 37.864758f,
33.253647f, 33.431873f, 33.62917f, 33.864613f, 34.15727f, 34.466957f,
34.677948f, 34.856377f, 35.077415f, 35.38916f, 35.623745f, 35.803745f,
36.005447f, 36.301834f, 36.565884f, 36.753647f, 36.941406f, 37.205456f,
37.50184f, 37.703545f, 37.883545f, 38.118122f, 38.4415f, 38.714653f,
38.9373f, 39.127235f, 39.297054f, 39.363884f, 39.340782f, 39.28496f,
34.464783f, 34.64301f, 34.840305f, 35.075752f, 35.368404f, 35.6781f,
35.889088f, 36.067516f, 36.28855f, 36.6003f, 36.834885f, 37.014877f,
37.216583f, 37.51297f, 37.77702f, 37.964783f, 38.152546f, 38.416595f,
38.71298f, 38.914684f, 39.094685f, 39.32926f, 39.652645f, 39.925793f,
40.14844f, 40.338375f, 40.508194f, 40.575024f, 40.55192f, 40.496105f,
36.058067f, 36.23629f, 36.43359f, 36.669033f, 36.961685f, 37.271378f,
37.48237f, 37.6608f, 37.881836f, 38.19359f, 38.42817f, 38.608162f,
38.809868f, 39.10625f, 39.3703f, 39.558064f, 39.74583f, 40.00988f,
40.306267f, 40.50797f, 40.68797f, 40.92255f, 41.245926f, 41.519077f,
41.741722f, 41.931652f, 42.101475f, 42.168304f, 42.145203f, 42.089386f,
38.315002f, 38.493233f, 38.690533f, 38.925976f, 39.218628f, 39.52832f,
39.739307f, 39.917736f, 40.138775f, 40.45052f, 40.685104f, 40.865097f,
41.066803f, 41.36319f, 41.627243f, 41.815002f, 42.002766f, 42.26682f,
42.5632f, 42.764908f, 42.944904f, 43.179485f, 43.50286f, 43.776016f,
43.998665f, 44.188595f, 44.358418f, 44.425247f, 44.402145f, 44.34633f,
40.22708f, 40.40531f, 40.602608f, 40.83805f, 41.130707f, 41.440395f,
41.651382f, 41.82982f, 42.050854f, 42.3626f, 42.597183f, 42.77718f,
42.97888f, 43.27527f, 43.53932f, 43.72708f, 43.914845f, 44.178894f,
44.47528f, 44.676983f, 44.856983f, 45.09156f, 45.41494f, 45.68809f,
45.91074f, 46.100674f, 46.270493f, 46.337322f, 46.31422f, 46.2584f,
41.785618f, 41.963844f, 42.161144f, 42.396584f, 42.68924f, 42.998936f,
43.209923f, 43.388355f, 43.609394f, 43.921143f, 44.15572f, 44.335716f,
44.53742f, 44.833805f, 45.09786f, 45.285614f, 45.473377f, 45.737427f,
46.033817f, 46.235523f, 46.415524f, 46.650105f, 46.973476f, 47.24663f,
47.469276f, 47.65921f, 47.82903f, 47.895855f, 47.872753f, 47.81694f,
43.11514f, 43.293365f, 43.490665f, 43.726105f, 44.018764f, 44.328457f,
44.539444f, 44.717873f, 44.93891f, 45.25066f, 45.48524f, 45.665237f,
45.86694f, 46.163326f, 46.427376f, 46.615143f, 46.802902f, 47.066956f,
47.363342f, 47.56505f, 47.74505f, 47.979626f, 48.302998f, 48.576153f,
48.798798f, 48.98873f, 49.158546f, 49.225376f, 49.202282f, 49.146458f,
44.303867f, 44.482094f, 44.679394f, 44.914833f, 45.207493f, 45.51718f,
45.72817f, 45.9066f, 46.12764f, 46.439384f, 46.673965f, 46.853966f,
47.055668f, 47.352055f, 47.6161f, 47.803867f, 47.99163f, 48.25568f,
48.552063f, 48.75377f, 48.933773f, 49.16835f, 49.491726f, 49.764877f,
49.987526f, 50.17746f, 50.347275f, 50.4141f, 50.391006f, 50.335186f,
44.771675f, 44.949905f, 45.1472f, 45.382645f, 45.6753f, 45.98499f,
46.195976f, 46.374413f, 46.595448f, 46.907196f, 47.141773f, 47.321774f,
47.523476f, 47.819862f, 48.08391f, 48.27168f, 48.459446f, 48.72349f,
49.019882f, 49.22158f, 49.401585f, 49.63616f, 49.959538f, 50.232693f,
50.455338f, 50.64527f, 50.81509f, 50.88192f, 50.858818f, 50.803f,
44.609966f, 44.788193f, 44.985493f, 45.220936f, 45.51359f, 45.82328f,
46.03427f, 46.2127f, 46.433743f, 46.74549f, 46.98007f, 47.160065f,
47.36177f, 47.658157f, 47.922207f, 48.10997f, 48.297733f, 48.561783f,
48.858166f, 49.059875f, 49.239872f, 49.47445f, 49.79783f, 50.07098f,
50.293625f, 50.48356f, 50.653378f, 50.720203f, 50.6971f, 50.64128f,
44.219246f, 44.397472f, 44.594772f, 44.83021f, 45.122868f, 45.43256f,
45.643543f, 45.82198f, 46.04302f, 46.354763f, 46.589344f, 46.76934f,
46.971046f, 47.267433f, 47.531483f, 47.719242f, 47.907005f, 48.17105f,
48.467438f, 48.66914f, 48.849144f, 49.08372f, 49.4071f, 49.680256f,
49.902905f, 50.092834f, 50.262653f, 50.329483f, 50.30638f, 50.25057f});
Shugeo image resize bicubic (#56) * Added implementation files for image_resize and resize_bicubic ops. * Image resize and image.resize_bicubic ops implementation. Initial revision. * Finished with infrastructure development for image.resize_bilinear op and image_resizo op implementation. * Refactored resize methods. * Added processing for Mitchelcubic algorithm. * Added check for input/output sizes. * Added int and float types for crop_and_resize op. * Refactored crop_and_resize output type check. * Added helper for bicubic interpolation as TF v.1 does. * Added TF v.1 bicubic helper for cuda platform. * Added cached class for bicubic algorithm. * Refactored cuda implementation for crop_and_resize helper to use proper output type. * Added facilities for bicubic interpolation. * Portion bicubic interpolation from TF. * Added tests for resize_bilinear testing. * Working implementation of bicubic interpolation and tests. * Refactored routines with image_resize bicubic op helper. * Refactored code with coding standards. * Refactored cpu helpers for resize_bicubic op. * Refactored bicubic helpers. * Added bicubic resize facilities. * Implementing cuda kernels for bicubic interpolation. Implementation step. * Cuda implementation of resize_bicubic op helper. * Refactor image.resize_bicubic op helpers. * Refactored helpers for resize_bicubic. Added error checking with cuda implementation. * Refactored cuda implementation of resize_bicubic op helper. The first working revision. * Cuda arch implementation for resize_bicubic op helper. Full working single-threaded revision. * Intermediate bicubic interpolation helper for cuda. * Refactored cpu helper for resize_bicubic. * Multithreaded cuda implementation for resize_bicubic. * Fixed merge issues. * Refactored nlp helpers. * Replicated resize_bicubic for 3D also. * Eliminated waste comments of unused code. * Eliminated waste comments with unused code. * Eliminated waste template definitions. * Eliminated waste debug code. * Eliminated waste comments. * Fixed multithreading with helpers. * Fixed test suites for float and double in float point input lists. * Fixed usage of reshape with 3D/4D on resizes. * Final fixes. * Fixed resize_neighbor op problem.
2019-11-20 20:11:04 +01:00
auto size = NDArrayFactory::create<int>({30, 30});
nd4j::ops::resize_bicubic op;
auto results = op.evaluate({&input, &size}, {}, {});
Shugeo image resize bicubic (#56) * Added implementation files for image_resize and resize_bicubic ops. * Image resize and image.resize_bicubic ops implementation. Initial revision. * Finished with infrastructure development for image.resize_bilinear op and image_resizo op implementation. * Refactored resize methods. * Added processing for Mitchelcubic algorithm. * Added check for input/output sizes. * Added int and float types for crop_and_resize op. * Refactored crop_and_resize output type check. * Added helper for bicubic interpolation as TF v.1 does. * Added TF v.1 bicubic helper for cuda platform. * Added cached class for bicubic algorithm. * Refactored cuda implementation for crop_and_resize helper to use proper output type. * Added facilities for bicubic interpolation. * Portion bicubic interpolation from TF. * Added tests for resize_bilinear testing. * Working implementation of bicubic interpolation and tests. * Refactored routines with image_resize bicubic op helper. * Refactored code with coding standards. * Refactored cpu helpers for resize_bicubic op. * Refactored bicubic helpers. * Added bicubic resize facilities. * Implementing cuda kernels for bicubic interpolation. Implementation step. * Cuda implementation of resize_bicubic op helper. * Refactor image.resize_bicubic op helpers. * Refactored helpers for resize_bicubic. Added error checking with cuda implementation. * Refactored cuda implementation of resize_bicubic op helper. The first working revision. * Cuda arch implementation for resize_bicubic op helper. Full working single-threaded revision. * Intermediate bicubic interpolation helper for cuda. * Refactored cpu helper for resize_bicubic. * Multithreaded cuda implementation for resize_bicubic. * Fixed merge issues. * Refactored nlp helpers. * Replicated resize_bicubic for 3D also. * Eliminated waste comments of unused code. * Eliminated waste comments with unused code. * Eliminated waste template definitions. * Eliminated waste debug code. * Eliminated waste comments. * Fixed multithreading with helpers. * Fixed test suites for float and double in float point input lists. * Fixed usage of reshape with 3D/4D on resizes. * Final fixes. * Fixed resize_neighbor op problem.
2019-11-20 20:11:04 +01:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Resized to 30x30");
// expected.printBuffer("Expect for 30x30");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test2) {
NDArray input = NDArrayFactory::create<double>('c', {2, 5, 4, 3});
NDArray expected = NDArrayFactory::create<float>('c', {2, 10, 8, 3}, {
1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, 4.000000f, 5.000000f, 6.000000f,
5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f,
10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f,
7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f,
11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f,
15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f,
16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f,
20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f,
19.000000f, 20.000000f, 21.000000f, 20.218750f, 21.218750f, 22.218750f, 22.000000f, 23.000000f, 24.000000f,
23.500000f, 24.500000f, 25.500000f, 25.000000f, 26.000000f, 27.000000f, 26.781250f, 27.781250f, 28.781250f,
28.000000f, 29.000000f, 30.000000f, 28.281250f, 29.281250f, 30.281250f, 25.000000f, 26.000000f, 27.000000f,
26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f,
31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f,
34.281250f, 35.281250f, 36.281250f, 31.000000f, 32.000000f, 33.000000f, 32.218750f, 33.218750f, 34.218750f,
34.000000f, 35.000000f, 36.000000f, 35.500000f, 36.500000f, 37.500000f, 37.000000f, 38.000000f, 39.000000f,
38.781250f, 39.781250f, 40.781250f, 40.000000f, 41.000000f, 42.000000f, 40.281250f, 41.281250f, 42.281250f,
37.000000f, 38.000000f, 39.000000f, 38.218750f, 39.218750f, 40.218750f, 40.000000f, 41.000000f, 42.000000f,
41.500000f, 42.500000f, 43.500000f, 43.000000f, 44.000000f, 45.000000f, 44.781250f, 45.781250f, 46.781250f,
46.000000f, 47.000000f, 48.000000f, 46.281250f, 47.281250f, 48.281250f, 44.125000f, 45.125000f, 46.125000f,
45.343750f, 46.343750f, 47.343750f, 47.125000f, 48.125000f, 49.125000f, 48.625000f, 49.625000f, 50.625000f,
50.125000f, 51.125000f, 52.125000f, 51.906250f, 52.906250f, 53.906250f, 53.125000f, 54.125000f, 55.125000f,
53.406250f, 54.406250f, 55.406250f, 49.000000f, 50.000000f, 51.000000f, 50.218750f, 51.218750f, 52.218750f,
52.000000f, 53.000000f, 54.000000f, 53.500000f, 54.500000f, 55.500000f, 55.000000f, 56.000000f, 57.000000f,
56.781250f, 57.781250f, 58.781250f, 58.000000f, 59.000000f, 60.000000f, 58.281250f, 59.281250f, 60.281250f,
50.125000f, 51.125000f, 52.125000f, 51.343750f, 52.343750f, 53.343750f, 53.125000f, 54.125000f, 55.125000f,
54.625000f, 55.625000f, 56.625000f, 56.125000f, 57.125000f, 58.125000f, 57.906250f, 58.906250f, 59.906250f,
59.125000f, 60.125000f, 61.125000f, 59.406250f, 60.406250f, 61.406250f, 61.000000f, 62.000000f, 63.000000f,
62.218750f, 63.218750f, 64.218750f, 64.000000f, 65.000000f, 66.000000f, 65.500000f, 66.500000f, 67.500000f,
67.000000f, 68.000000f, 69.000000f, 68.781250f, 69.781250f, 70.781250f, 70.000000f, 71.000000f, 72.000000f,
70.281250f, 71.281250f, 72.281250f, 65.875000f, 66.875000f, 67.875000f, 67.093750f, 68.093750f, 69.093750f,
68.875000f, 69.875000f, 70.875000f, 70.375000f, 71.375000f, 72.375000f, 71.875000f, 72.875000f, 73.875000f,
73.656250f, 74.656250f, 75.656250f, 74.875000f, 75.875000f, 76.875000f, 75.156250f, 76.156250f, 77.156250f,
73.000000f, 74.000000f, 75.000000f, 74.218750f, 75.218750f, 76.218750f, 76.000000f, 77.000000f, 78.000000f,
77.500000f, 78.500000f, 79.500000f, 79.000000f, 80.000000f, 81.000000f, 80.781250f, 81.781250f, 82.781250f,
82.000000f, 83.000000f, 84.000000f, 82.281250f, 83.281250f, 84.281250f, 79.000000f, 80.000000f, 81.000000f,
80.218750f, 81.218750f, 82.218750f, 82.000000f, 83.000000f, 84.000000f, 83.500000f, 84.500000f, 85.500000f,
85.000000f, 86.000000f, 87.000000f, 86.781250f, 87.781250f, 88.781250f, 88.000000f, 89.000000f, 90.000000f,
88.281250f, 89.281250f, 90.281250f, 85.000000f, 86.000000f, 87.000000f, 86.218750f, 87.218750f, 88.218750f,
88.000000f, 89.000000f, 90.000000f, 89.500000f, 90.500000f, 91.500000f, 91.000000f, 92.000000f, 93.000000f,
92.781250f, 93.781250f, 94.781250f, 94.000000f, 95.000000f, 96.000000f, 94.281250f, 95.281250f, 96.281250f,
91.000000f, 92.000000f, 93.000000f, 92.218750f, 93.218750f, 94.218750f, 94.000000f, 95.000000f, 96.000000f,
95.500000f, 96.500000f, 97.500000f, 97.000000f, 98.000000f, 99.000000f, 98.781250f, 99.781250f, 100.781250f,
100.000000f, 101.000000f, 102.000000f, 100.281250f, 101.281250f, 102.281250f, 97.000000f, 98.000000f,
99.000000f, 98.218750f, 99.218750f, 100.218750f, 100.000000f, 101.000000f, 102.000000f, 101.500000f,
102.500000f, 103.500000f, 103.000000f, 104.000000f, 105.000000f, 104.781250f, 105.781250f, 106.781250f,
106.000000f, 107.000000f, 108.000000f, 106.281250f, 107.281250f, 108.281250f, 104.125000f, 105.125000f,
106.125000f, 105.343750f, 106.343750f, 107.343750f, 107.125000f, 108.125000f, 109.125000f, 108.625000f,
109.625000f, 110.625000f, 110.125000f, 111.125000f, 112.125000f, 111.906250f, 112.906250f, 113.906250f,
113.125000f, 114.125000f, 115.125000f, 113.406250f, 114.406250f, 115.406250f, 109.000000f, 110.000000f,
111.000000f, 110.218750f, 111.218750f, 112.218750f, 112.000000f, 113.000000f, 114.000000f, 113.500000f,
114.500000f, 115.500000f, 115.000000f, 116.000000f, 117.000000f, 116.781250f, 117.781250f, 118.781250f,
118.000000f, 119.000000f, 120.000000f, 118.281250f, 119.281250f, 120.281250f, 110.125000f, 111.125000f,
112.125000f, 111.343750f, 112.343750f, 113.343750f, 113.125000f, 114.125000f, 115.125000f, 114.625000f,
115.625000f, 116.625000f, 116.125000f, 117.125000f, 118.125000f, 117.906250f, 118.906250f, 119.906250f,
119.125000f, 120.125000f, 121.125000f, 119.406250f, 120.406250f, 121.406250f
}); //input = 1.f;
Shugeo image resize bicubic (#56) * Added implementation files for image_resize and resize_bicubic ops. * Image resize and image.resize_bicubic ops implementation. Initial revision. * Finished with infrastructure development for image.resize_bilinear op and image_resizo op implementation. * Refactored resize methods. * Added processing for Mitchelcubic algorithm. * Added check for input/output sizes. * Added int and float types for crop_and_resize op. * Refactored crop_and_resize output type check. * Added helper for bicubic interpolation as TF v.1 does. * Added TF v.1 bicubic helper for cuda platform. * Added cached class for bicubic algorithm. * Refactored cuda implementation for crop_and_resize helper to use proper output type. * Added facilities for bicubic interpolation. * Portion bicubic interpolation from TF. * Added tests for resize_bilinear testing. * Working implementation of bicubic interpolation and tests. * Refactored routines with image_resize bicubic op helper. * Refactored code with coding standards. * Refactored cpu helpers for resize_bicubic op. * Refactored bicubic helpers. * Added bicubic resize facilities. * Implementing cuda kernels for bicubic interpolation. Implementation step. * Cuda implementation of resize_bicubic op helper. * Refactor image.resize_bicubic op helpers. * Refactored helpers for resize_bicubic. Added error checking with cuda implementation. * Refactored cuda implementation of resize_bicubic op helper. The first working revision. * Cuda arch implementation for resize_bicubic op helper. Full working single-threaded revision. * Intermediate bicubic interpolation helper for cuda. * Refactored cpu helper for resize_bicubic. * Multithreaded cuda implementation for resize_bicubic. * Fixed merge issues. * Refactored nlp helpers. * Replicated resize_bicubic for 3D also. * Eliminated waste comments of unused code. * Eliminated waste comments with unused code. * Eliminated waste template definitions. * Eliminated waste debug code. * Eliminated waste comments. * Fixed multithreading with helpers. * Fixed test suites for float and double in float point input lists. * Fixed usage of reshape with 3D/4D on resizes. * Final fixes. * Fixed resize_neighbor op problem.
2019-11-20 20:11:04 +01:00
input.linspace(1);
auto size = NDArrayFactory::create<int>({10, 8});
nd4j::ops::resize_bicubic op;
auto results = op.evaluate({&input, &size}, {}, {});
Shugeo image resize bicubic (#56) * Added implementation files for image_resize and resize_bicubic ops. * Image resize and image.resize_bicubic ops implementation. Initial revision. * Finished with infrastructure development for image.resize_bilinear op and image_resizo op implementation. * Refactored resize methods. * Added processing for Mitchelcubic algorithm. * Added check for input/output sizes. * Added int and float types for crop_and_resize op. * Refactored crop_and_resize output type check. * Added helper for bicubic interpolation as TF v.1 does. * Added TF v.1 bicubic helper for cuda platform. * Added cached class for bicubic algorithm. * Refactored cuda implementation for crop_and_resize helper to use proper output type. * Added facilities for bicubic interpolation. * Portion bicubic interpolation from TF. * Added tests for resize_bilinear testing. * Working implementation of bicubic interpolation and tests. * Refactored routines with image_resize bicubic op helper. * Refactored code with coding standards. * Refactored cpu helpers for resize_bicubic op. * Refactored bicubic helpers. * Added bicubic resize facilities. * Implementing cuda kernels for bicubic interpolation. Implementation step. * Cuda implementation of resize_bicubic op helper. * Refactor image.resize_bicubic op helpers. * Refactored helpers for resize_bicubic. Added error checking with cuda implementation. * Refactored cuda implementation of resize_bicubic op helper. The first working revision. * Cuda arch implementation for resize_bicubic op helper. Full working single-threaded revision. * Intermediate bicubic interpolation helper for cuda. * Refactored cpu helper for resize_bicubic. * Multithreaded cuda implementation for resize_bicubic. * Fixed merge issues. * Refactored nlp helpers. * Replicated resize_bicubic for 3D also. * Eliminated waste comments of unused code. * Eliminated waste comments with unused code. * Eliminated waste template definitions. * Eliminated waste debug code. * Eliminated waste comments. * Fixed multithreading with helpers. * Fixed test suites for float and double in float point input lists. * Fixed usage of reshape with 3D/4D on resizes. * Final fixes. * Fixed resize_neighbor op problem.
2019-11-20 20:11:04 +01:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Resized to 10x8");
// expected.printBuffer("Expect for 10x8");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test3) {
NDArray input = NDArrayFactory::create<double>('c', {1, 3, 3, 4});
NDArray expected = NDArrayFactory::create<float>('c', {1, 6, 6, 4}, {
1.000000f, 2.000000f, 3.000000f, 4.000000f, 2.625000f, 3.625000f, 4.625000f, 5.625000f, 5.000000f,
6.000000f, 7.000000f, 8.000000f, 7.375000f, 8.375000f, 9.375000f, 10.375000f, 9.000000f, 10.000000f,
11.000000f, 12.000000f, 9.375000f, 10.375000f, 11.375000f, 12.375000f, 5.875000f, 6.875000f, 7.875000f,
8.875000f, 7.500000f, 8.500000f, 9.500000f, 10.500000f, 9.875000f, 10.875000f, 11.875000f, 12.875000f,
12.250000f, 13.250000f, 14.250000f, 15.250000f, 13.875000f, 14.875000f, 15.875000f, 16.875000f, 14.250000f,
15.250000f, 16.250000f, 17.250000f, 13.000000f, 14.000000f, 15.000000f, 16.000000f, 14.625000f, 15.625000f,
16.625000f, 17.625000f, 17.000000f, 18.000000f, 19.000000f, 20.000000f, 19.375000f, 20.375000f, 21.375000f,
22.375000f, 21.000000f, 22.000000f, 23.000000f, 24.000000f, 21.375000f, 22.375000f, 23.375000f, 24.375000f,
20.125000f, 21.125000f, 22.125000f, 23.125000f, 21.750000f, 22.750000f, 23.750000f, 24.750000f, 24.125000f,
25.125000f, 26.125000f, 27.125000f, 26.500000f, 27.500000f, 28.500000f, 29.500000f, 28.125000f, 29.125000f,
30.125000f, 31.125000f, 28.500000f, 29.500000f, 30.500000f, 31.500000f, 25.000000f, 26.000000f, 27.000000f,
28.000000f, 26.625000f, 27.625000f, 28.625000f, 29.625000f, 29.000000f, 30.000000f, 31.000000f, 32.000000f,
31.375000f, 32.375000f, 33.375000f, 34.375000f, 33.000000f, 34.000000f, 35.000000f, 36.000000f, 33.375000f,
34.375000f, 35.375000f, 36.375000f, 26.125000f, 27.125000f, 28.125000f, 29.125000f, 27.750000f, 28.750000f,
29.750000f, 30.750000f, 30.125000f, 31.125000f, 32.125000f, 33.125000f, 32.500000f, 33.500000f, 34.500000f,
35.500000f, 34.125000f, 35.125000f, 36.125000f, 37.125000f, 34.500000f, 35.500000f, 36.500000f, 37.500000f
Shugeo image resize bicubic (#56) * Added implementation files for image_resize and resize_bicubic ops. * Image resize and image.resize_bicubic ops implementation. Initial revision. * Finished with infrastructure development for image.resize_bilinear op and image_resizo op implementation. * Refactored resize methods. * Added processing for Mitchelcubic algorithm. * Added check for input/output sizes. * Added int and float types for crop_and_resize op. * Refactored crop_and_resize output type check. * Added helper for bicubic interpolation as TF v.1 does. * Added TF v.1 bicubic helper for cuda platform. * Added cached class for bicubic algorithm. * Refactored cuda implementation for crop_and_resize helper to use proper output type. * Added facilities for bicubic interpolation. * Portion bicubic interpolation from TF. * Added tests for resize_bilinear testing. * Working implementation of bicubic interpolation and tests. * Refactored routines with image_resize bicubic op helper. * Refactored code with coding standards. * Refactored cpu helpers for resize_bicubic op. * Refactored bicubic helpers. * Added bicubic resize facilities. * Implementing cuda kernels for bicubic interpolation. Implementation step. * Cuda implementation of resize_bicubic op helper. * Refactor image.resize_bicubic op helpers. * Refactored helpers for resize_bicubic. Added error checking with cuda implementation. * Refactored cuda implementation of resize_bicubic op helper. The first working revision. * Cuda arch implementation for resize_bicubic op helper. Full working single-threaded revision. * Intermediate bicubic interpolation helper for cuda. * Refactored cpu helper for resize_bicubic. * Multithreaded cuda implementation for resize_bicubic. * Fixed merge issues. * Refactored nlp helpers. * Replicated resize_bicubic for 3D also. * Eliminated waste comments of unused code. * Eliminated waste comments with unused code. * Eliminated waste template definitions. * Eliminated waste debug code. * Eliminated waste comments. * Fixed multithreading with helpers. * Fixed test suites for float and double in float point input lists. * Fixed usage of reshape with 3D/4D on resizes. * Final fixes. * Fixed resize_neighbor op problem.
2019-11-20 20:11:04 +01:00
});
input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_bicubic op;
auto results = op.evaluate({&input, &size}, {}, {});
Shugeo image resize bicubic (#56) * Added implementation files for image_resize and resize_bicubic ops. * Image resize and image.resize_bicubic ops implementation. Initial revision. * Finished with infrastructure development for image.resize_bilinear op and image_resizo op implementation. * Refactored resize methods. * Added processing for Mitchelcubic algorithm. * Added check for input/output sizes. * Added int and float types for crop_and_resize op. * Refactored crop_and_resize output type check. * Added helper for bicubic interpolation as TF v.1 does. * Added TF v.1 bicubic helper for cuda platform. * Added cached class for bicubic algorithm. * Refactored cuda implementation for crop_and_resize helper to use proper output type. * Added facilities for bicubic interpolation. * Portion bicubic interpolation from TF. * Added tests for resize_bilinear testing. * Working implementation of bicubic interpolation and tests. * Refactored routines with image_resize bicubic op helper. * Refactored code with coding standards. * Refactored cpu helpers for resize_bicubic op. * Refactored bicubic helpers. * Added bicubic resize facilities. * Implementing cuda kernels for bicubic interpolation. Implementation step. * Cuda implementation of resize_bicubic op helper. * Refactor image.resize_bicubic op helpers. * Refactored helpers for resize_bicubic. Added error checking with cuda implementation. * Refactored cuda implementation of resize_bicubic op helper. The first working revision. * Cuda arch implementation for resize_bicubic op helper. Full working single-threaded revision. * Intermediate bicubic interpolation helper for cuda. * Refactored cpu helper for resize_bicubic. * Multithreaded cuda implementation for resize_bicubic. * Fixed merge issues. * Refactored nlp helpers. * Replicated resize_bicubic for 3D also. * Eliminated waste comments of unused code. * Eliminated waste comments with unused code. * Eliminated waste template definitions. * Eliminated waste debug code. * Eliminated waste comments. * Fixed multithreading with helpers. * Fixed test suites for float and double in float point input lists. * Fixed usage of reshape with 3D/4D on resizes. * Final fixes. * Fixed resize_neighbor op problem.
2019-11-20 20:11:04 +01:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Resized to 6x6");
// expected.printBuffer("Expect for 6x6");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test4) {
NDArray input = NDArrayFactory::create<double>('c', {1, 3, 4, 3});
NDArray expected = NDArrayFactory::create<float>('c', {1, 6, 8, 3}, {
1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, 4.000000f, 5.000000f, 6.000000f,
5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f,
10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f,
7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f,
11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f,
15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f,
16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f,
20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f,
20.125000f, 21.125000f, 22.125000f, 21.343750f, 22.343750f, 23.343750f, 23.125000f, 24.125000f, 25.125000f,
24.625000f, 25.625000f, 26.625000f, 26.125000f, 27.125000f, 28.125000f, 27.906250f, 28.906250f, 29.906250f,
29.125000f, 30.125000f, 31.125000f, 29.406250f, 30.406250f, 31.406250f, 25.000000f, 26.000000f, 27.000000f,
26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f,
31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f,
34.281250f, 35.281250f, 36.281250f, 26.125000f, 27.125000f, 28.125000f, 27.343750f, 28.343750f, 29.343750f,
29.125000f, 30.125000f, 31.125000f, 30.625000f, 31.625000f, 32.625000f, 32.125000f, 33.125000f, 34.125000f,
33.906250f, 34.906250f, 35.906250f, 35.125000f, 36.125000f, 37.125000f, 35.406250f, 36.406250f, 37.406250f
});
Shugeo image resize bicubic (#56) * Added implementation files for image_resize and resize_bicubic ops. * Image resize and image.resize_bicubic ops implementation. Initial revision. * Finished with infrastructure development for image.resize_bilinear op and image_resizo op implementation. * Refactored resize methods. * Added processing for Mitchelcubic algorithm. * Added check for input/output sizes. * Added int and float types for crop_and_resize op. * Refactored crop_and_resize output type check. * Added helper for bicubic interpolation as TF v.1 does. * Added TF v.1 bicubic helper for cuda platform. * Added cached class for bicubic algorithm. * Refactored cuda implementation for crop_and_resize helper to use proper output type. * Added facilities for bicubic interpolation. * Portion bicubic interpolation from TF. * Added tests for resize_bilinear testing. * Working implementation of bicubic interpolation and tests. * Refactored routines with image_resize bicubic op helper. * Refactored code with coding standards. * Refactored cpu helpers for resize_bicubic op. * Refactored bicubic helpers. * Added bicubic resize facilities. * Implementing cuda kernels for bicubic interpolation. Implementation step. * Cuda implementation of resize_bicubic op helper. * Refactor image.resize_bicubic op helpers. * Refactored helpers for resize_bicubic. Added error checking with cuda implementation. * Refactored cuda implementation of resize_bicubic op helper. The first working revision. * Cuda arch implementation for resize_bicubic op helper. Full working single-threaded revision. * Intermediate bicubic interpolation helper for cuda. * Refactored cpu helper for resize_bicubic. * Multithreaded cuda implementation for resize_bicubic. * Fixed merge issues. * Refactored nlp helpers. * Replicated resize_bicubic for 3D also. * Eliminated waste comments of unused code. * Eliminated waste comments with unused code. * Eliminated waste template definitions. * Eliminated waste debug code. * Eliminated waste comments. * Fixed multithreading with helpers. * Fixed test suites for float and double in float point input lists. * Fixed usage of reshape with 3D/4D on resizes. * Final fixes. * Fixed resize_neighbor op problem.
2019-11-20 20:11:04 +01:00
input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 8});
nd4j::ops::resize_bicubic op;
auto results = op.evaluate({&input, &size}, {}, {});
Shugeo image resize bicubic (#56) * Added implementation files for image_resize and resize_bicubic ops. * Image resize and image.resize_bicubic ops implementation. Initial revision. * Finished with infrastructure development for image.resize_bilinear op and image_resizo op implementation. * Refactored resize methods. * Added processing for Mitchelcubic algorithm. * Added check for input/output sizes. * Added int and float types for crop_and_resize op. * Refactored crop_and_resize output type check. * Added helper for bicubic interpolation as TF v.1 does. * Added TF v.1 bicubic helper for cuda platform. * Added cached class for bicubic algorithm. * Refactored cuda implementation for crop_and_resize helper to use proper output type. * Added facilities for bicubic interpolation. * Portion bicubic interpolation from TF. * Added tests for resize_bilinear testing. * Working implementation of bicubic interpolation and tests. * Refactored routines with image_resize bicubic op helper. * Refactored code with coding standards. * Refactored cpu helpers for resize_bicubic op. * Refactored bicubic helpers. * Added bicubic resize facilities. * Implementing cuda kernels for bicubic interpolation. Implementation step. * Cuda implementation of resize_bicubic op helper. * Refactor image.resize_bicubic op helpers. * Refactored helpers for resize_bicubic. Added error checking with cuda implementation. * Refactored cuda implementation of resize_bicubic op helper. The first working revision. * Cuda arch implementation for resize_bicubic op helper. Full working single-threaded revision. * Intermediate bicubic interpolation helper for cuda. * Refactored cpu helper for resize_bicubic. * Multithreaded cuda implementation for resize_bicubic. * Fixed merge issues. * Refactored nlp helpers. * Replicated resize_bicubic for 3D also. * Eliminated waste comments of unused code. * Eliminated waste comments with unused code. * Eliminated waste template definitions. * Eliminated waste debug code. * Eliminated waste comments. * Fixed multithreading with helpers. * Fixed test suites for float and double in float point input lists. * Fixed usage of reshape with 3D/4D on resizes. * Final fixes. * Fixed resize_neighbor op problem.
2019-11-20 20:11:04 +01:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Resized to 6x8");
// expected.printBuffer("Expect for 6x8");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test5) {
NDArray input = NDArrayFactory::create<double>('c', {1, 4, 4, 3});
NDArray expected = NDArrayFactory::create<float>('c', {1, 8, 8, 3}, {
1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, 4.000000f, 5.000000f, 6.000000f,
5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f,
10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f,
7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f,
11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f,
15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f,
16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f,
20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f,
19.000000f, 20.000000f, 21.000000f, 20.218750f, 21.218750f, 22.218750f, 22.000000f, 23.000000f, 24.000000f,
23.500000f, 24.500000f, 25.500000f, 25.000000f, 26.000000f, 27.000000f, 26.781250f, 27.781250f, 28.781250f,
28.000000f, 29.000000f, 30.000000f, 28.281250f, 29.281250f, 30.281250f, 25.000000f, 26.000000f, 27.000000f,
26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f,
31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f,
34.281250f, 35.281250f, 36.281250f, 32.125000f, 33.125000f, 34.125000f, 33.343750f, 34.343750f, 35.343750f,
35.125000f, 36.125000f, 37.125000f, 36.625000f, 37.625000f, 38.625000f, 38.125000f, 39.125000f, 40.125000f,
39.906250f, 40.906250f, 41.906250f, 41.125000f, 42.125000f, 43.125000f, 41.406250f, 42.406250f, 43.406250f,
37.000000f, 38.000000f, 39.000000f, 38.218750f, 39.218750f, 40.218750f, 40.000000f, 41.000000f, 42.000000f,
41.500000f, 42.500000f, 43.500000f, 43.000000f, 44.000000f, 45.000000f, 44.781250f, 45.781250f, 46.781250f,
46.000000f, 47.000000f, 48.000000f, 46.281250f, 47.281250f, 48.281250f, 38.125000f, 39.125000f, 40.125000f,
39.343750f, 40.343750f, 41.343750f, 41.125000f, 42.125000f, 43.125000f, 42.625000f, 43.625000f, 44.625000f,
44.125000f, 45.125000f, 46.125000f, 45.906250f, 46.906250f, 47.906250f, 47.125000f, 48.125000f, 49.125000f,
47.406250f, 48.406250f, 49.406250f,
});
Shugeo image resize bicubic (#56) * Added implementation files for image_resize and resize_bicubic ops. * Image resize and image.resize_bicubic ops implementation. Initial revision. * Finished with infrastructure development for image.resize_bilinear op and image_resizo op implementation. * Refactored resize methods. * Added processing for Mitchelcubic algorithm. * Added check for input/output sizes. * Added int and float types for crop_and_resize op. * Refactored crop_and_resize output type check. * Added helper for bicubic interpolation as TF v.1 does. * Added TF v.1 bicubic helper for cuda platform. * Added cached class for bicubic algorithm. * Refactored cuda implementation for crop_and_resize helper to use proper output type. * Added facilities for bicubic interpolation. * Portion bicubic interpolation from TF. * Added tests for resize_bilinear testing. * Working implementation of bicubic interpolation and tests. * Refactored routines with image_resize bicubic op helper. * Refactored code with coding standards. * Refactored cpu helpers for resize_bicubic op. * Refactored bicubic helpers. * Added bicubic resize facilities. * Implementing cuda kernels for bicubic interpolation. Implementation step. * Cuda implementation of resize_bicubic op helper. * Refactor image.resize_bicubic op helpers. * Refactored helpers for resize_bicubic. Added error checking with cuda implementation. * Refactored cuda implementation of resize_bicubic op helper. The first working revision. * Cuda arch implementation for resize_bicubic op helper. Full working single-threaded revision. * Intermediate bicubic interpolation helper for cuda. * Refactored cpu helper for resize_bicubic. * Multithreaded cuda implementation for resize_bicubic. * Fixed merge issues. * Refactored nlp helpers. * Replicated resize_bicubic for 3D also. * Eliminated waste comments of unused code. * Eliminated waste comments with unused code. * Eliminated waste template definitions. * Eliminated waste debug code. * Eliminated waste comments. * Fixed multithreading with helpers. * Fixed test suites for float and double in float point input lists. * Fixed usage of reshape with 3D/4D on resizes. * Final fixes. * Fixed resize_neighbor op problem.
2019-11-20 20:11:04 +01:00
input.linspace(1);
auto size = NDArrayFactory::create<int>({8, 8});
nd4j::ops::resize_bicubic op;
auto results = op.evaluate({&input, &size}, {}, {});
Shugeo image resize bicubic (#56) * Added implementation files for image_resize and resize_bicubic ops. * Image resize and image.resize_bicubic ops implementation. Initial revision. * Finished with infrastructure development for image.resize_bilinear op and image_resizo op implementation. * Refactored resize methods. * Added processing for Mitchelcubic algorithm. * Added check for input/output sizes. * Added int and float types for crop_and_resize op. * Refactored crop_and_resize output type check. * Added helper for bicubic interpolation as TF v.1 does. * Added TF v.1 bicubic helper for cuda platform. * Added cached class for bicubic algorithm. * Refactored cuda implementation for crop_and_resize helper to use proper output type. * Added facilities for bicubic interpolation. * Portion bicubic interpolation from TF. * Added tests for resize_bilinear testing. * Working implementation of bicubic interpolation and tests. * Refactored routines with image_resize bicubic op helper. * Refactored code with coding standards. * Refactored cpu helpers for resize_bicubic op. * Refactored bicubic helpers. * Added bicubic resize facilities. * Implementing cuda kernels for bicubic interpolation. Implementation step. * Cuda implementation of resize_bicubic op helper. * Refactor image.resize_bicubic op helpers. * Refactored helpers for resize_bicubic. Added error checking with cuda implementation. * Refactored cuda implementation of resize_bicubic op helper. The first working revision. * Cuda arch implementation for resize_bicubic op helper. Full working single-threaded revision. * Intermediate bicubic interpolation helper for cuda. * Refactored cpu helper for resize_bicubic. * Multithreaded cuda implementation for resize_bicubic. * Fixed merge issues. * Refactored nlp helpers. * Replicated resize_bicubic for 3D also. * Eliminated waste comments of unused code. * Eliminated waste comments with unused code. * Eliminated waste template definitions. * Eliminated waste debug code. * Eliminated waste comments. * Fixed multithreading with helpers. * Fixed test suites for float and double in float point input lists. * Fixed usage of reshape with 3D/4D on resizes. * Final fixes. * Fixed resize_neighbor op problem.
2019-11-20 20:11:04 +01:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Resized to 8x8");
// expected.printBuffer("Expect for 8x8");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) {
NDArray input = NDArrayFactory::create<float>('c', {7, 7, 1}, {
1.f, 2.1f, 3.15f, 4.2f, 5.15f, 6.1f, 7.f,
8.f, 9.1f, 10.f, 11.f, 12.9f, 13.1f, 14.f,
15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f,
22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f,
30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f,
37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f,
44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f
Shugeo image resize bicubic (#56) * Added implementation files for image_resize and resize_bicubic ops. * Image resize and image.resize_bicubic ops implementation. Initial revision. * Finished with infrastructure development for image.resize_bilinear op and image_resizo op implementation. * Refactored resize methods. * Added processing for Mitchelcubic algorithm. * Added check for input/output sizes. * Added int and float types for crop_and_resize op. * Refactored crop_and_resize output type check. * Added helper for bicubic interpolation as TF v.1 does. * Added TF v.1 bicubic helper for cuda platform. * Added cached class for bicubic algorithm. * Refactored cuda implementation for crop_and_resize helper to use proper output type. * Added facilities for bicubic interpolation. * Portion bicubic interpolation from TF. * Added tests for resize_bilinear testing. * Working implementation of bicubic interpolation and tests. * Refactored routines with image_resize bicubic op helper. * Refactored code with coding standards. * Refactored cpu helpers for resize_bicubic op. * Refactored bicubic helpers. * Added bicubic resize facilities. * Implementing cuda kernels for bicubic interpolation. Implementation step. * Cuda implementation of resize_bicubic op helper. * Refactor image.resize_bicubic op helpers. * Refactored helpers for resize_bicubic. Added error checking with cuda implementation. * Refactored cuda implementation of resize_bicubic op helper. The first working revision. * Cuda arch implementation for resize_bicubic op helper. Full working single-threaded revision. * Intermediate bicubic interpolation helper for cuda. * Refactored cpu helper for resize_bicubic. * Multithreaded cuda implementation for resize_bicubic. * Fixed merge issues. * Refactored nlp helpers. * Replicated resize_bicubic for 3D also. * Eliminated waste comments of unused code. * Eliminated waste comments with unused code. * Eliminated waste template definitions. * Eliminated waste debug code. * Eliminated waste comments. * Fixed multithreading with helpers. * Fixed test suites for float and double in float point input lists. * Fixed usage of reshape with 3D/4D on resizes. * Final fixes. * Fixed resize_neighbor op problem.
2019-11-20 20:11:04 +01:00
});
NDArray expected = NDArrayFactory::create<float>('c', {30, 30, 1}, {
1.000000f, 1.197616f, 1.417436f, 1.677577f, 1.996158f, 2.328327f, 2.550918f, 2.736061f, 2.965541f,
3.292965f, 3.544151f, 3.738035f, 3.948995f, 4.248106f, 4.507379f, 4.684374f, 4.857284f, 5.104302f,
5.386991f, 5.581401f, 5.753962f, 5.974285f, 6.272836f, 6.520426f, 6.718899f, 6.887104f, 7.039068f,
7.099216f, 7.078424f, 7.028189f, 2.247592f, 2.446947f, 2.669489f, 2.931238f, 3.248216f, 3.574534f,
3.789310f, 3.965697f, 4.186417f, 4.504653f, 4.740569f, 4.921706f, 5.133866f, 5.459533f, 5.774461f,
6.019787f, 6.254011f, 6.535633f, 6.809730f, 6.960779f, 7.074942f, 7.241601f, 7.509489f, 7.749949f,
7.954571f, 8.131972f, 8.286526f, 8.346463f, 8.325745f, 8.275683f, 3.628684f, 3.830573f, 4.056959f,
4.321157f, 4.636486f, 4.955650f, 5.160583f, 5.325847f, 5.535462f, 5.842160f, 6.058749f, 6.223753f,
6.437597f, 6.797369f, 7.183604f, 7.516402f, 7.829034f, 8.154773f, 8.417635f, 8.512958f, 8.552100f,
8.649708f, 8.877880f, 9.108794f, 9.320926f, 9.509781f, 9.667375f, 9.726940f, 9.706349f, 9.656599f,
5.276778f, 5.480438f, 5.709702f, 5.975448f, 6.288551f, 6.600570f, 6.796207f, 6.951142f, 7.150400f,
7.446143f, 7.644651f, 7.794562f, 8.009684f, 8.400473f, 8.851847f, 9.264690f, 9.649218f, 10.015648f,
10.268647f, 10.313368f, 10.284327f, 10.319379f, 10.512033f, 10.734956f, 10.954604f, 11.154507f, 11.315369f,
11.374779f, 11.354242f, 11.304622f, 7.325373f, 7.528484f, 7.757575f, 8.022221f, 8.331997f, 8.638187f,
8.827649f, 8.976217f, 9.168955f, 9.457260f, 9.644237f, 9.784517f, 9.999621f, 10.407702f, 10.896234f,
11.355122f, 11.781423f, 12.172186f, 12.420712f, 12.437449f, 12.370511f, 12.371386f, 12.545973f, 12.766424f,
12.992249f, 13.200120f, 13.364252f, 13.424109f, 13.403420f, 13.353425f, 9.493208f, 9.692467f, 9.916944f,
10.176801f, 10.482199f, 10.785470f, 10.974367f, 11.123442f, 11.316370f, 11.603645f, 11.790616f, 11.930889f,
12.144082f, 12.546447f, 13.024898f, 13.472300f, 13.889232f, 14.276275f, 14.528972f, 14.555555f, 14.501450f,
14.515459f, 14.700572f, 14.927055f, 15.156046f, 15.366046f, 15.532901f, 15.594008f, 15.572885f, 15.521847f,
10.970133f, 11.163599f, 11.380694f, 11.633735f, 11.935032f, 12.238887f, 12.432540f, 12.588294f, 12.787534f,
13.079956f, 13.277520f, 13.426631f, 13.636713f, 14.013844f, 14.441672f, 14.827978f, 15.191209f, 15.549808f,
15.813430f, 15.881828f, 15.883522f, 15.950411f, 16.169330f, 16.407940f, 16.636436f, 16.842583f, 17.010887f,
17.073630f, 17.051940f, 16.999537f, 12.219155f, 12.406129f, 12.614796f, 12.860335f, 13.157928f, 13.464224f,
13.665207f, 13.830567f, 14.039036f, 14.339629f, 14.552863f, 14.715049f, 14.921564f, 15.264454f, 15.622843f,
15.924977f, 16.213829f, 16.532364f, 16.809900f, 16.934835f, 17.012146f, 17.150164f, 17.413412f, 17.666712f,
17.892765f, 18.092070f, 18.261044f, 18.325531f, 18.303238f, 18.249378f, 13.766397f, 13.947391f, 14.148263f,
14.386917f, 14.681246f, 14.990087f, 15.198166f, 15.372728f, 15.590062f, 15.898583f, 16.126892f, 16.301655f,
16.504870f, 16.815214f, 17.107498f, 17.329458f, 17.547403f, 17.827654f, 18.118288f, 18.296928f, 18.446100f,
18.651634f, 18.956806f, 19.223820f, 19.447308f, 19.639887f, 19.809319f, 19.875397f, 19.852556f, 19.797365f,
15.941937f, 16.118704f, 16.314133f, 16.547867f, 16.839561f, 17.149540f, 17.361883f, 17.542162f, 17.764957f,
18.078188f, 18.315733f, 18.498205f, 18.699116f, 18.988684f, 19.238989f, 19.410137f, 19.583265f, 19.839512f,
20.138780f, 20.351770f, 20.546844f, 20.795671f, 21.128067f, 21.404358f, 21.626736f, 21.815500f, 21.985610f,
22.052843f, 22.029604f, 21.973448f, 17.535220f, 17.710770f, 17.904636f, 18.136950f, 18.427840f, 18.738056f,
18.951529f, 19.133352f, 19.357613f, 19.672083f, 19.912102f, 20.096638f, 20.296894f, 20.580765f, 20.819603f,
20.976887f, 21.137802f, 21.387535f, 21.689209f, 21.911621f, 22.119276f, 22.379990f, 22.719910f, 22.998823f,
23.220970f, 23.408760f, 23.579110f, 23.646685f, 23.623325f, 23.566887f, 18.746353f, 18.922657f, 19.117487f,
19.350685f, 19.642070f, 19.952137f, 20.164913f, 20.345781f, 20.569134f, 20.882840f, 21.121330f, 21.304590f,
21.505253f, 21.792645f, 22.038572f, 22.204426f, 22.372890f, 22.626648f, 22.926834f, 23.143423f, 23.343302f,
23.596668f, 23.931936f, 24.209232f, 24.431519f, 24.619913f, 24.790110f, 24.857473f, 24.834190f, 24.777927f,
20.166560f, 20.344206f, 20.540766f, 20.775532f, 21.067804f, 21.377607f, 21.589132f, 21.768297f, 21.990030f,
22.302366f, 22.538124f, 22.719105f, 22.920494f, 23.214176f, 23.472767f, 23.653934f, 23.835890f, 24.096842f,
24.394371f, 24.600555f, 24.786541f, 25.026773f, 25.353731f, 25.628130f, 25.850672f, 26.040140f, 26.210072f,
26.277063f, 26.253906f, 26.197956f, 22.363024f, 22.541250f, 22.738552f, 22.973991f, 23.266647f, 23.576340f,
23.787327f, 23.965760f, 24.186796f, 24.498543f, 24.733124f, 24.913122f, 25.114826f, 25.411213f, 25.675262f,
25.863028f, 26.050789f, 26.314838f, 26.611223f, 26.812925f, 26.992926f, 27.227505f, 27.550882f, 27.824034f,
28.046684f, 28.236614f, 28.406433f, 28.473265f, 28.450163f, 28.394344f, 24.429443f, 24.607670f, 24.804970f,
25.040410f, 25.333065f, 25.642756f, 25.853743f, 26.032173f, 26.253210f, 26.564959f, 26.799540f, 26.979540f,
27.181242f, 27.477630f, 27.741680f, 27.929441f, 28.117207f, 28.381254f, 28.677637f, 28.879343f, 29.059345f,
29.293922f, 29.617298f, 29.890451f, 30.113104f, 30.303034f, 30.472853f, 30.539684f, 30.516582f, 30.460762f,
26.000000f, 26.178228f, 26.375526f, 26.610970f, 26.903624f, 27.213314f, 27.424305f, 27.602734f, 27.823772f,
28.135519f, 28.370100f, 28.550098f, 28.751800f, 29.048190f, 29.312237f, 29.500000f, 29.687763f, 29.951813f,
30.248200f, 30.449903f, 30.629902f, 30.864483f, 31.187859f, 31.461012f, 31.683659f, 31.873592f, 32.043407f,
32.110240f, 32.087135f, 32.031320f, 27.570559f, 27.748787f, 27.946087f, 28.181528f, 28.474184f, 28.783876f,
28.994865f, 29.173294f, 29.394330f, 29.706080f, 29.940659f, 30.120655f, 30.322360f, 30.618746f, 30.882797f,
31.070557f, 31.258320f, 31.522371f, 31.818754f, 32.020460f, 32.200460f, 32.435040f, 32.758415f, 33.031567f,
33.254220f, 33.444150f, 33.613964f, 33.680794f, 33.657696f, 33.601880f, 29.636976f, 29.815207f, 30.012500f,
30.247944f, 30.540600f, 30.850290f, 31.061283f, 31.239712f, 31.460750f, 31.772500f, 32.007080f, 32.187077f,
32.388780f, 32.685165f, 32.949215f, 33.136980f, 33.324740f, 33.588790f, 33.885178f, 34.086884f, 34.266880f,
34.501457f, 34.824837f, 35.097990f, 35.320637f, 35.510574f, 35.680390f, 35.747215f, 35.724117f, 35.668300f,
31.833440f, 32.011665f, 32.208970f, 32.444412f, 32.737070f, 33.046757f, 33.257744f, 33.436176f, 33.657207f,
33.968960f, 34.203537f, 34.383537f, 34.585240f, 34.881630f, 35.145676f, 35.333440f, 35.521206f, 35.785255f,
36.081642f, 36.283340f, 36.463340f, 36.697920f, 37.021297f, 37.294453f, 37.517097f, 37.707027f, 37.876846f,
37.943680f, 37.920578f, 37.864758f, 33.253647f, 33.431873f, 33.629170f, 33.864613f, 34.157270f, 34.466957f,
34.677948f, 34.856377f, 35.077415f, 35.389160f, 35.623745f, 35.803745f, 36.005447f, 36.301834f, 36.565884f,
36.753647f, 36.941406f, 37.205456f, 37.501840f, 37.703545f, 37.883545f, 38.118122f, 38.441500f, 38.714653f,
38.937300f, 39.127235f, 39.297054f, 39.363884f, 39.340782f, 39.284960f, 34.464783f, 34.643010f, 34.840305f,
35.075752f, 35.368404f, 35.678100f, 35.889088f, 36.067516f, 36.288550f, 36.600300f, 36.834885f, 37.014877f,
37.216583f, 37.512970f, 37.777020f, 37.964783f, 38.152546f, 38.416595f, 38.712980f, 38.914684f, 39.094685f,
39.329260f, 39.652645f, 39.925793f, 40.148440f, 40.338375f, 40.508194f, 40.575024f, 40.551920f, 40.496105f,
36.058067f, 36.236290f, 36.433590f, 36.669033f, 36.961685f, 37.271378f, 37.482370f, 37.660800f, 37.881836f,
38.193590f, 38.428170f, 38.608162f, 38.809868f, 39.106250f, 39.370300f, 39.558064f, 39.745830f, 40.009880f,
40.306267f, 40.507970f, 40.687970f, 40.922550f, 41.245926f, 41.519077f, 41.741722f, 41.931652f, 42.101475f,
42.168304f, 42.145203f, 42.089386f, 38.315002f, 38.493233f, 38.690533f, 38.925976f, 39.218628f, 39.528320f,
39.739307f, 39.917736f, 40.138775f, 40.450520f, 40.685104f, 40.865097f, 41.066803f, 41.363190f, 41.627243f,
41.815002f, 42.002766f, 42.266820f, 42.563200f, 42.764908f, 42.944904f, 43.179485f, 43.502860f, 43.776016f,
43.998665f, 44.188595f, 44.358418f, 44.425247f, 44.402145f, 44.346330f, 40.227080f, 40.405310f, 40.602608f,
40.838050f, 41.130707f, 41.440395f, 41.651382f, 41.829820f, 42.050854f, 42.362600f, 42.597183f, 42.777180f,
42.978880f, 43.275270f, 43.539320f, 43.727080f, 43.914845f, 44.178894f, 44.475280f, 44.676983f, 44.856983f,
45.091560f, 45.414940f, 45.688090f, 45.910740f, 46.100674f, 46.270493f, 46.337322f, 46.314220f, 46.258400f,
41.785618f, 41.963844f, 42.161144f, 42.396584f, 42.689240f, 42.998936f, 43.209923f, 43.388355f, 43.609394f,
43.921143f, 44.155720f, 44.335716f, 44.537420f, 44.833805f, 45.097860f, 45.285614f, 45.473377f, 45.737427f,
46.033817f, 46.235523f, 46.415524f, 46.650105f, 46.973476f, 47.246630f, 47.469276f, 47.659210f, 47.829030f,
47.895855f, 47.872753f, 47.816940f, 43.115140f, 43.293365f, 43.490665f, 43.726105f, 44.018764f, 44.328457f,
44.539444f, 44.717873f, 44.938910f, 45.250660f, 45.485240f, 45.665237f, 45.866940f, 46.163326f, 46.427376f,
46.615143f, 46.802902f, 47.066956f, 47.363342f, 47.565050f, 47.745050f, 47.979626f, 48.302998f, 48.576153f,
48.798798f, 48.988730f, 49.158546f, 49.225376f, 49.202282f, 49.146458f, 44.303867f, 44.482094f, 44.679394f,
44.914833f, 45.207493f, 45.517180f, 45.728170f, 45.906600f, 46.127640f, 46.439384f, 46.673965f, 46.853966f,
47.055668f, 47.352055f, 47.616100f, 47.803867f, 47.991630f, 48.255680f, 48.552063f, 48.753770f, 48.933773f,
49.168350f, 49.491726f, 49.764877f, 49.987526f, 50.177460f, 50.347275f, 50.414100f, 50.391006f, 50.335186f,
44.771675f, 44.949905f, 45.147200f, 45.382645f, 45.675300f, 45.984990f, 46.195976f, 46.374413f, 46.595448f,
46.907196f, 47.141773f, 47.321774f, 47.523476f, 47.819862f, 48.083910f, 48.271680f, 48.459446f, 48.723490f,
49.019882f, 49.221580f, 49.401585f, 49.636160f, 49.959538f, 50.232693f, 50.455338f, 50.645270f, 50.815090f,
50.881920f, 50.858818f, 50.803000f, 44.609966f, 44.788193f, 44.985493f, 45.220936f, 45.513590f, 45.823280f,
46.034270f, 46.212700f, 46.433743f, 46.745490f, 46.980070f, 47.160065f, 47.361770f, 47.658157f, 47.922207f,
48.109970f, 48.297733f, 48.561783f, 48.858166f, 49.059875f, 49.239872f, 49.474450f, 49.797830f, 50.070980f,
50.293625f, 50.483560f, 50.653378f, 50.720203f, 50.697100f, 50.641280f, 44.219246f, 44.397472f, 44.594772f,
44.830210f, 45.122868f, 45.432560f, 45.643543f, 45.821980f, 46.043020f, 46.354763f, 46.589344f, 46.769340f,
46.971046f, 47.267433f, 47.531483f, 47.719242f, 47.907005f, 48.171050f, 48.467438f, 48.669140f, 48.849144f,
49.083720f, 49.407100f, 49.680256f, 49.902905f, 50.092834f, 50.262653f, 50.329483f, 50.306380f, 50.250570f
});
Shugeo image resize bicubic (#56) * Added implementation files for image_resize and resize_bicubic ops. * Image resize and image.resize_bicubic ops implementation. Initial revision. * Finished with infrastructure development for image.resize_bilinear op and image_resizo op implementation. * Refactored resize methods. * Added processing for Mitchelcubic algorithm. * Added check for input/output sizes. * Added int and float types for crop_and_resize op. * Refactored crop_and_resize output type check. * Added helper for bicubic interpolation as TF v.1 does. * Added TF v.1 bicubic helper for cuda platform. * Added cached class for bicubic algorithm. * Refactored cuda implementation for crop_and_resize helper to use proper output type. * Added facilities for bicubic interpolation. * Portion bicubic interpolation from TF. * Added tests for resize_bilinear testing. * Working implementation of bicubic interpolation and tests. * Refactored routines with image_resize bicubic op helper. * Refactored code with coding standards. * Refactored cpu helpers for resize_bicubic op. * Refactored bicubic helpers. * Added bicubic resize facilities. * Implementing cuda kernels for bicubic interpolation. Implementation step. * Cuda implementation of resize_bicubic op helper. * Refactor image.resize_bicubic op helpers. * Refactored helpers for resize_bicubic. Added error checking with cuda implementation. * Refactored cuda implementation of resize_bicubic op helper. The first working revision. * Cuda arch implementation for resize_bicubic op helper. Full working single-threaded revision. * Intermediate bicubic interpolation helper for cuda. * Refactored cpu helper for resize_bicubic. * Multithreaded cuda implementation for resize_bicubic. * Fixed merge issues. * Refactored nlp helpers. * Replicated resize_bicubic for 3D also. * Eliminated waste comments of unused code. * Eliminated waste comments with unused code. * Eliminated waste template definitions. * Eliminated waste debug code. * Eliminated waste comments. * Fixed multithreading with helpers. * Fixed test suites for float and double in float point input lists. * Fixed usage of reshape with 3D/4D on resizes. * Final fixes. * Fixed resize_neighbor op problem.
2019-11-20 20:11:04 +01:00
auto size = NDArrayFactory::create<int>({30, 30});
nd4j::ops::resize_bicubic op;
auto results = op.evaluate({&input, &size}, {}, {});
Shugeo image resize bicubic (#56) * Added implementation files for image_resize and resize_bicubic ops. * Image resize and image.resize_bicubic ops implementation. Initial revision. * Finished with infrastructure development for image.resize_bilinear op and image_resizo op implementation. * Refactored resize methods. * Added processing for Mitchelcubic algorithm. * Added check for input/output sizes. * Added int and float types for crop_and_resize op. * Refactored crop_and_resize output type check. * Added helper for bicubic interpolation as TF v.1 does. * Added TF v.1 bicubic helper for cuda platform. * Added cached class for bicubic algorithm. * Refactored cuda implementation for crop_and_resize helper to use proper output type. * Added facilities for bicubic interpolation. * Portion bicubic interpolation from TF. * Added tests for resize_bilinear testing. * Working implementation of bicubic interpolation and tests. * Refactored routines with image_resize bicubic op helper. * Refactored code with coding standards. * Refactored cpu helpers for resize_bicubic op. * Refactored bicubic helpers. * Added bicubic resize facilities. * Implementing cuda kernels for bicubic interpolation. Implementation step. * Cuda implementation of resize_bicubic op helper. * Refactor image.resize_bicubic op helpers. * Refactored helpers for resize_bicubic. Added error checking with cuda implementation. * Refactored cuda implementation of resize_bicubic op helper. The first working revision. * Cuda arch implementation for resize_bicubic op helper. Full working single-threaded revision. * Intermediate bicubic interpolation helper for cuda. * Refactored cpu helper for resize_bicubic. * Multithreaded cuda implementation for resize_bicubic. * Fixed merge issues. * Refactored nlp helpers. * Replicated resize_bicubic for 3D also. * Eliminated waste comments of unused code. * Eliminated waste comments with unused code. * Eliminated waste template definitions. * Eliminated waste debug code. * Eliminated waste comments. * Fixed multithreading with helpers. * Fixed test suites for float and double in float point input lists. * Fixed usage of reshape with 3D/4D on resizes. * Final fixes. * Fixed resize_neighbor op problem.
2019-11-20 20:11:04 +01:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Resized to 30x30");
// expected.printBuffer("Expect for 30x30");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test7) {
NDArray input = NDArrayFactory::create<double>('c', {2, 5, 5, 1}, {
0.2303, 0.7950, 0.8171, 0.0451, 0.3690, 0.6846, 0.2727, 0.2770, 0.2381, 0.9511,
0.4116, 0.3997, 0.4075, 0.6275, 0.8018, 0.0678, 0.6221, 0.2982, 0.1524, 0.2613,
0.7425, 0.6036, 0.7926, 0.5838, 0.1361, 0.4154, 0.3634, 0.3741, 0.2088, 0.2989,
0.3982, 0.5618, 0.7266, 0.1089, 0.2922, 0.3306, 0.2869, 0.6638, 0.3091, 0.9312,
0.0240, 0.2893, 0.5632, 0.9625, 0.4189, 0.3854, 0.2743, 0.6754, 0.8820, 0.8699});
NDArray expected = NDArrayFactory::create<float>('c', {2, 9, 9, 1}, {
0.2303f, 0.54569f, 0.840649f, 0.92725444f, 0.65660673f,
0.16641647f, 0.06117659f, 0.33279106f, 0.4023279f, 0.5139505f,
0.49821317f, 0.4906872f, 0.537642f, 0.4070102f, 0.13030615f,
0.258801f, 0.65352744f, 0.773368f, 0.69225276f, 0.44177493f,
0.21910316f, 0.22368976f, 0.24221404f, 0.21399781f, 0.5114972f,
0.9169859f, 1.0511527f, 0.5608501f, 0.41315168f, 0.2913824f,
0.2966933f, 0.38585684f, 0.48849702f, 0.71013063f, 0.9086001f,
0.9794303f, 0.29625386f, 0.39427578f, 0.45971435f, 0.39693952f,
0.40860707f, 0.51061106f, 0.6181093f, 0.67309624f, 0.69564015f,
0.06012487f, 0.3863805f, 0.58993465f, 0.40679216f, 0.22607432f,
0.20093678f, 0.25901243f, 0.3615362f, 0.39371052f, 0.24176767f,
0.4868709f, 0.650651f, 0.5493148f, 0.3825456f, 0.27788478f,
0.18927254f, 0.16692996f, 0.15432167f, 0.677519f, 0.6236242f,
0.61700624f, 0.7214321f, 0.7307374f, 0.6251454f, 0.3924176f,
0.17802659f, 0.10231908f, 0.81192374f, 0.66878575f, 0.6118803f,
0.7797006f, 0.8396968f, 0.72889954f, 0.44547448f, 0.16794783f,
0.07125802f, 0.4154f, 0.38504714f, 0.3623221f, 0.3862173f,
0.3397379f, 0.23285517f, 0.21876639f, 0.2892362f, 0.30817088f,
0.41268015f, 0.45587808f, 0.51991886f, 0.60977113f, 0.49489656f,
0.21313031f, 0.11297428f, 0.2167207f, 0.23940037f, 0.39337245f,
0.46112412f, 0.583034f, 0.76207364f, 0.6326203f, 0.22189438f,
0.12071565f, 0.3275853f, 0.3794855f, 0.38497013f, 0.35049653f,
0.41895086f, 0.671095f, 0.62119365f, 0.22362521f, 0.30189657f,
0.72530353f, 0.85048175f, 0.2524255f, 0.2182264f, 0.2964637f,
0.5361996f, 0.6255393f, 0.46424767f, 0.5741281f, 0.8408146f,
0.92403257f, 0.04648584f, 0.14959256f, 0.32215607f, 0.46194845f,
0.6642166f, 0.83560026f, 0.7663391f, 0.5284251f, 0.4573109f,
0.10357999f, 0.17442937f, 0.32116935f, 0.45530772f, 0.7163773f,
0.9856574f, 0.8976148f, 0.5538923f, 0.45173654f, 0.34958175f,
0.2680429f, 0.30470955f, 0.51233786f, 0.75128907f, 0.86736864f,
0.8982046f, 0.83254474f, 0.8168574f, 0.4225865f, 0.2956836f,
0.29948136f, 0.5276342f, 0.76461166f, 0.8442875f, 0.907862f,
0.9139262f, 0.92068815f
});
auto size = NDArrayFactory::create<int>({9, 9});
nd4j::ops::resize_bicubic op;
auto results = op.evaluate({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Resized to 9x9");
// expected.printBuffer("Expect for 9x9");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) {
NDArray input = NDArrayFactory::create<double>('c', {2, 5, 5, 1}, {
0.23028551377579154, 0.7949972231516509, 0.8171307820461517, 0.04507309923418412, 0.3689673597428338,
0.6845757584903018, 0.27268547668219667, 0.2770196372806053, 0.2381478370531429, 0.9511201914609859,
0.41160882670429033, 0.3997152563642703, 0.4074505147711718, 0.6274595060113246, 0.8017922711300232,
0.06782045852179475, 0.6220772280691722, 0.2982335327629251, 0.1523603480424196, 0.2612986044295986,
0.7424762244324299, 0.6036156464824591, 0.7926371071102005, 0.5838270656432538, 0.13607200219168547,
0.4154002170215956, 0.36340617544852116, 0.37405031188276827, 0.20880251686544882, 0.298919946410666,
0.39820758164277126, 0.5617728968896589, 0.72660225993937, 0.10888245916813699, 0.29215797784445496,
0.3305531351746034, 0.28693451964931715, 0.6637635348315494, 0.30913418229827583, 0.9312186188801752,
0.0239594182399363, 0.2892942758780874, 0.5631691110629038, 0.9625499752246309, 0.4189439089689968,
0.3854304088214935, 0.27426304203925045, 0.6754051704648238, 0.8820362490795286, 0.8699337744328859});
auto testData = NDArrayFactory::create<float>('c', {2,9,9,1}, {
0.230286514f, 0.510566354f, 0.794997215f, 0.931386113f, 0.817130804f, 0.402811885f, 0.045073099f, 0.134639814f, 0.368967354f,
0.483021289f, 0.501266003f, 0.521932304f, 0.572325349f, 0.534847379f, 0.267853439f, 0.105112493f, 0.349290252f, 0.674043298f,
0.684575737f, 0.478224277f, 0.272685468f, 0.239882097f, 0.27701965f, 0.191148892f, 0.23814784f, 0.590989769f, 0.951120198f,
0.622912169f, 0.441326082f, 0.266387194f, 0.232538164f, 0.301838756f, 0.356378645f, 0.495445013f, 0.756725252f, 0.981704295f,
0.411608815f, 0.40493685f, 0.399715245f, 0.381842017f, 0.407450527f, 0.501836538f, 0.627459526f, 0.735251725f, 0.801792264f,
0.150875032f, 0.357000858f, 0.524536073f, 0.450354964f, 0.318719596f, 0.319606483f, 0.385957927f, 0.46392554f, 0.529285908f,
0.06782046f, 0.375309169f, 0.622077227f, 0.525792599f, 0.298233539f, 0.184723631f, 0.15236035f, 0.193153858f, 0.261298597f,
0.372918189f, 0.512539625f, 0.63369292f, 0.628733814f, 0.535196245f, 0.436597466f, 0.323553175f, 0.215942055f, 0.148014024f,
0.742476225f, 0.655325174f, 0.603615642f, 0.704684138f, 0.79263711f, 0.747929871f, 0.583827078f, 0.340373576f, 0.136071995f,
0.415400207f, 0.388405323f, 0.363406181f, 0.379345775f, 0.374050319f, 0.28397581f, 0.208802521f, 0.238369256f, 0.298919946f,
0.413146496f, 0.444389015f, 0.488355637f, 0.568351328f, 0.556217432f, 0.345546633f, 0.140068889f, 0.148834035f, 0.23562704f,
0.398207575f, 0.464537472f, 0.561772883f, 0.717433035f, 0.726602256f, 0.416013002f, 0.108882457f, 0.142608985f, 0.292157978f,
0.391511708f, 0.389470309f, 0.442729384f, 0.651181757f, 0.737665415f, 0.41685915f, 0.138383076f, 0.342548877f, 0.659080088f,
0.330553144f, 0.273416102f, 0.286934525f, 0.50450629f, 0.663763523f, 0.463456154f, 0.309134185f, 0.586929917f, 0.931218624f,
0.137025774f, 0.169145152f, 0.263757467f, 0.436182201f, 0.597053051f, 0.657990932f, 0.662163854f, 0.68354249f, 0.692712903f,
0.023959421f, 0.130951077f, 0.289294273f, 0.413664877f, 0.563169122f, 0.839498401f, 0.962549984f, 0.728188932f, 0.418943912f,
0.175951749f, 0.198239252f, 0.281999886f, 0.420836329f, 0.609856486f, 0.863734365f, 0.983550847f, 0.825015843f, 0.596413136f,
0.385430396f, 0.292239636f, 0.274263054f, 0.445040524f, 0.675405145f, 0.817462444f, 0.882036269f, 0.895356655f, 0.869933784f
});
auto size = NDArrayFactory::create<int>({9, 9});
nd4j::ops::resize_bicubic op;
auto results = op.evaluate({&input, &size}, {}, {}, {true, false});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Resized to 9x9");
// testData.printBuffer("Expect for 9x9");
ASSERT_TRUE(testData.isSameShape(result));
ASSERT_TRUE(testData.equalsTo(result));
Shugeo resize area (#162) * Added implementation for resize_area op. Initial commit. * Added implementation of resize_area op. Initial revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected resizeArea functor call. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_area. Cpu platform helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation for resize_area helpers. The first part revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added a set of tests for resize_area op. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation for resize_area. Initial approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Adding multithreading for resize_area algorithm. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation of resize_area helpers. Shared memory approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resizeAreaKernel with cuda implementation. * Eliminated compilation errors. * ResizeArea helpers for cuda platform. The first working revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test for batched resize_area op testing. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_are for cuda platform and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed multithreading with resize_area op helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright marks with sources. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for resize_area op implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for parity ops header. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected typo in strings and so on with image resize ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resize_area helpers and multithreading. Signed-off-by: shugeo <sgazeos@gmail.com> * Added ResizeArea wrapper * Added test with align_corners and fixed shape processing with only int args given for output size. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test * TF mapping for ResizeArea * Fixed implementation issues with resize_area op for both platforms. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored image resizer struct to use flexible types for ints and floats. Signed-off-by: shugeo <sgazeos@gmail.com> * Improved multithreading with resizeAreaKernel launch. Signed-off-by: shugeo <sgazeos@gmail.com> * Use asynchronical memory copying with cuda platform image resize allocations. Signed-off-by: shugeo <sgazeos@gmail.com> Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
2020-01-22 08:46:33 +01:00
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test1) {
NDArray input = NDArrayFactory::create<double>('c', {1, 3, 3, 4});
NDArray expected = NDArrayFactory::create<float>('c', {1, 6, 6, 4}, {
1.f, 2.f, 3.f, 4.f,
1.f, 2.f, 3.f, 4.f,
5.f, 6.f, 7.f, 8.f,
5.f, 6.f, 7.f, 8.f,
9.f, 10.f, 11.f, 12.f,
9.f, 10.f, 11.f, 12.f,
1.f, 2.f, 3.f, 4.f,
1.f, 2.f, 3.f, 4.f,
5.f, 6.f, 7.f, 8.f,
5.f, 6.f, 7.f, 8.f,
9.f, 10.f, 11.f, 12.f,
9.f, 10.f, 11.f, 12.f,
13.f, 14.f, 15.f, 16.f,
13.f, 14.f, 15.f, 16.f,
17.f, 18.f, 19.f, 20.f,
17.f, 18.f, 19.f, 20.f,
21.f, 22.f, 23.f, 24.f,
21.f, 22.f, 23.f, 24.f,
13.f, 14.f, 15.f, 16.f,
13.f, 14.f, 15.f, 16.f,
17.f, 18.f, 19.f, 20.f,
17.f, 18.f, 19.f, 20.f,
21.f, 22.f, 23.f, 24.f,
21.f, 22.f, 23.f, 24.f,
25.f, 26.f, 27.f, 28.f,
25.f, 26.f, 27.f, 28.f,
29.f, 30.f, 31.f, 32.f,
29.f, 30.f, 31.f, 32.f,
33.f, 34.f, 35.f, 36.f,
33.f, 34.f, 35.f, 36.f,
25.f, 26.f, 27.f, 28.f,
25.f, 26.f, 27.f, 28.f,
29.f, 30.f, 31.f, 32.f,
29.f, 30.f, 31.f, 32.f,
33.f, 34.f, 35.f, 36.f,
33.f, 34.f, 35.f, 36.f });
input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.evaluate({&input, &size}, {}, {});
Shugeo resize area (#162) * Added implementation for resize_area op. Initial commit. * Added implementation of resize_area op. Initial revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected resizeArea functor call. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_area. Cpu platform helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation for resize_area helpers. The first part revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added a set of tests for resize_area op. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation for resize_area. Initial approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Adding multithreading for resize_area algorithm. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation of resize_area helpers. Shared memory approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resizeAreaKernel with cuda implementation. * Eliminated compilation errors. * ResizeArea helpers for cuda platform. The first working revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test for batched resize_area op testing. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_are for cuda platform and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed multithreading with resize_area op helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright marks with sources. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for resize_area op implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for parity ops header. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected typo in strings and so on with image resize ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resize_area helpers and multithreading. Signed-off-by: shugeo <sgazeos@gmail.com> * Added ResizeArea wrapper * Added test with align_corners and fixed shape processing with only int args given for output size. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test * TF mapping for ResizeArea * Fixed implementation issues with resize_area op for both platforms. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored image resizer struct to use flexible types for ints and floats. Signed-off-by: shugeo <sgazeos@gmail.com> * Improved multithreading with resizeAreaKernel launch. Signed-off-by: shugeo <sgazeos@gmail.com> * Use asynchronical memory copying with cuda platform image resize allocations. Signed-off-by: shugeo <sgazeos@gmail.com> Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
2020-01-22 08:46:33 +01:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 6x6");
// expected.printBuffer("Area Expect for 6x6");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test2) {
NDArray input = NDArrayFactory::create<float>('c', {1, 3, 3, 1});
NDArray expected = NDArrayFactory::create<float>('c', {1, 6, 6, 1}, {
1.f, 1.f, 2.f, 2.f, 3.f, 3.f,
1.f, 1.f, 2.f, 2.f, 3.f, 3.f,
4.f, 4.f, 5.f, 5.f, 6.f, 6.f,
4.f, 4.f, 5.f, 5.f, 6.f, 6.f,
7.f, 7.f, 8.f, 8.f, 9.f, 9.f,
7.f, 7.f, 8.f, 8.f, 9.f, 9.f
});
input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.evaluate({&input, &size}, {}, {});
Shugeo resize area (#162) * Added implementation for resize_area op. Initial commit. * Added implementation of resize_area op. Initial revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected resizeArea functor call. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_area. Cpu platform helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation for resize_area helpers. The first part revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added a set of tests for resize_area op. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation for resize_area. Initial approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Adding multithreading for resize_area algorithm. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation of resize_area helpers. Shared memory approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resizeAreaKernel with cuda implementation. * Eliminated compilation errors. * ResizeArea helpers for cuda platform. The first working revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test for batched resize_area op testing. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_are for cuda platform and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed multithreading with resize_area op helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright marks with sources. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for resize_area op implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for parity ops header. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected typo in strings and so on with image resize ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resize_area helpers and multithreading. Signed-off-by: shugeo <sgazeos@gmail.com> * Added ResizeArea wrapper * Added test with align_corners and fixed shape processing with only int args given for output size. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test * TF mapping for ResizeArea * Fixed implementation issues with resize_area op for both platforms. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored image resizer struct to use flexible types for ints and floats. Signed-off-by: shugeo <sgazeos@gmail.com> * Improved multithreading with resizeAreaKernel launch. Signed-off-by: shugeo <sgazeos@gmail.com> * Use asynchronical memory copying with cuda platform image resize allocations. Signed-off-by: shugeo <sgazeos@gmail.com> Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
2020-01-22 08:46:33 +01:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 6x6");
// expected.printBuffer("Area Expect for 6x6");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test3) {
NDArray input = NDArrayFactory::create<float>('c', {1, 3, 3, 3});
NDArray expected = NDArrayFactory::create<float>('c', {1, 6, 6, 3}, {
1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f
});
input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.evaluate({&input, &size}, {}, {});
Shugeo resize area (#162) * Added implementation for resize_area op. Initial commit. * Added implementation of resize_area op. Initial revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected resizeArea functor call. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_area. Cpu platform helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation for resize_area helpers. The first part revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added a set of tests for resize_area op. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation for resize_area. Initial approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Adding multithreading for resize_area algorithm. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation of resize_area helpers. Shared memory approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resizeAreaKernel with cuda implementation. * Eliminated compilation errors. * ResizeArea helpers for cuda platform. The first working revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test for batched resize_area op testing. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_are for cuda platform and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed multithreading with resize_area op helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright marks with sources. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for resize_area op implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for parity ops header. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected typo in strings and so on with image resize ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resize_area helpers and multithreading. Signed-off-by: shugeo <sgazeos@gmail.com> * Added ResizeArea wrapper * Added test with align_corners and fixed shape processing with only int args given for output size. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test * TF mapping for ResizeArea * Fixed implementation issues with resize_area op for both platforms. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored image resizer struct to use flexible types for ints and floats. Signed-off-by: shugeo <sgazeos@gmail.com> * Improved multithreading with resizeAreaKernel launch. Signed-off-by: shugeo <sgazeos@gmail.com> * Use asynchronical memory copying with cuda platform image resize allocations. Signed-off-by: shugeo <sgazeos@gmail.com> Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
2020-01-22 08:46:33 +01:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 6x6");
// expected.printBuffer("Area Expect for 6x6");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test4) {
NDArray input = NDArrayFactory::create<float>('c', {2, 3, 3, 3}, {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27
});
NDArray expected = NDArrayFactory::create<float>('c', {2, 6, 6, 3}, {
1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f,
1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f
});
//input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.evaluate({&input, &size}, {}, {});
Shugeo resize area (#162) * Added implementation for resize_area op. Initial commit. * Added implementation of resize_area op. Initial revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected resizeArea functor call. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_area. Cpu platform helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation for resize_area helpers. The first part revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added a set of tests for resize_area op. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation for resize_area. Initial approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Adding multithreading for resize_area algorithm. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation of resize_area helpers. Shared memory approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resizeAreaKernel with cuda implementation. * Eliminated compilation errors. * ResizeArea helpers for cuda platform. The first working revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test for batched resize_area op testing. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_are for cuda platform and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed multithreading with resize_area op helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright marks with sources. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for resize_area op implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for parity ops header. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected typo in strings and so on with image resize ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resize_area helpers and multithreading. Signed-off-by: shugeo <sgazeos@gmail.com> * Added ResizeArea wrapper * Added test with align_corners and fixed shape processing with only int args given for output size. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test * TF mapping for ResizeArea * Fixed implementation issues with resize_area op for both platforms. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored image resizer struct to use flexible types for ints and floats. Signed-off-by: shugeo <sgazeos@gmail.com> * Improved multithreading with resizeAreaKernel launch. Signed-off-by: shugeo <sgazeos@gmail.com> * Use asynchronical memory copying with cuda platform image resize allocations. Signed-off-by: shugeo <sgazeos@gmail.com> Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
2020-01-22 08:46:33 +01:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 6x6");
// expected.printBuffer("Area Expect for 6x6");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test5) {
NDArray input = NDArrayFactory::create<int>('c', {2, 3, 3, 3}, {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27
});
NDArray expected = NDArrayFactory::create<float>('c', {2, 6, 6, 3}, {
1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f,
1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f
});
//input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.evaluate({&input, &size}, {}, {});
Shugeo resize area (#162) * Added implementation for resize_area op. Initial commit. * Added implementation of resize_area op. Initial revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected resizeArea functor call. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_area. Cpu platform helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation for resize_area helpers. The first part revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added a set of tests for resize_area op. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation for resize_area. Initial approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Adding multithreading for resize_area algorithm. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation of resize_area helpers. Shared memory approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resizeAreaKernel with cuda implementation. * Eliminated compilation errors. * ResizeArea helpers for cuda platform. The first working revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test for batched resize_area op testing. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_are for cuda platform and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed multithreading with resize_area op helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright marks with sources. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for resize_area op implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for parity ops header. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected typo in strings and so on with image resize ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resize_area helpers and multithreading. Signed-off-by: shugeo <sgazeos@gmail.com> * Added ResizeArea wrapper * Added test with align_corners and fixed shape processing with only int args given for output size. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test * TF mapping for ResizeArea * Fixed implementation issues with resize_area op for both platforms. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored image resizer struct to use flexible types for ints and floats. Signed-off-by: shugeo <sgazeos@gmail.com> * Improved multithreading with resizeAreaKernel launch. Signed-off-by: shugeo <sgazeos@gmail.com> * Use asynchronical memory copying with cuda platform image resize allocations. Signed-off-by: shugeo <sgazeos@gmail.com> Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
2020-01-22 08:46:33 +01:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 6x6");
// expected.printBuffer("Area Expect for 6x6");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test6) {
NDArray input = NDArrayFactory::create<int>('c', {2, 3, 3, 1}, {
1, 2, 3, 4, 5, 6, 7, 8, 9,
1, 2, 3, 4, 5, 6, 7, 8, 9
});
NDArray expected = NDArrayFactory::create<float>('c', {2, 6, 6, 1}, {
1.f, 1.f, 1.5f, 2.f, 2.f, 3.f,
1.f, 1.f, 1.5f, 2.f, 2.f, 3.f,
2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f,
4.f, 4.f, 4.5f, 5.f, 5.f, 6.f,
4.f, 4.f, 4.5f, 5.f, 5.f, 6.f,
7.f, 7.f, 7.5f, 8.f, 8.f, 9.f,
1.f, 1.f, 1.5f, 2.f, 2.f, 3.f,
1.f, 1.f, 1.5f, 2.f, 2.f, 3.f,
2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f,
4.f, 4.f, 4.5f, 5.f, 5.f, 6.f,
4.f, 4.f, 4.5f, 5.f, 5.f, 6.f,
7.f, 7.f, 7.5f, 8.f, 8.f, 9.f
});
//input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.evaluate({&input, &size}, {}, {}, {true});
Shugeo resize area (#162) * Added implementation for resize_area op. Initial commit. * Added implementation of resize_area op. Initial revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected resizeArea functor call. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_area. Cpu platform helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation for resize_area helpers. The first part revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added a set of tests for resize_area op. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation for resize_area. Initial approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Adding multithreading for resize_area algorithm. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation of resize_area helpers. Shared memory approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resizeAreaKernel with cuda implementation. * Eliminated compilation errors. * ResizeArea helpers for cuda platform. The first working revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test for batched resize_area op testing. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_are for cuda platform and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed multithreading with resize_area op helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright marks with sources. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for resize_area op implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for parity ops header. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected typo in strings and so on with image resize ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resize_area helpers and multithreading. Signed-off-by: shugeo <sgazeos@gmail.com> * Added ResizeArea wrapper * Added test with align_corners and fixed shape processing with only int args given for output size. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test * TF mapping for ResizeArea * Fixed implementation issues with resize_area op for both platforms. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored image resizer struct to use flexible types for ints and floats. Signed-off-by: shugeo <sgazeos@gmail.com> * Improved multithreading with resizeAreaKernel launch. Signed-off-by: shugeo <sgazeos@gmail.com> * Use asynchronical memory copying with cuda platform image resize allocations. Signed-off-by: shugeo <sgazeos@gmail.com> Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
2020-01-22 08:46:33 +01:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 6x6");
// expected.printBuffer("Area Expect for 6x6");
Shugeo resize area (#162) * Added implementation for resize_area op. Initial commit. * Added implementation of resize_area op. Initial revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected resizeArea functor call. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_area. Cpu platform helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation for resize_area helpers. The first part revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added a set of tests for resize_area op. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation for resize_area. Initial approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Adding multithreading for resize_area algorithm. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation of resize_area helpers. Shared memory approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resizeAreaKernel with cuda implementation. * Eliminated compilation errors. * ResizeArea helpers for cuda platform. The first working revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test for batched resize_area op testing. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_are for cuda platform and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed multithreading with resize_area op helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright marks with sources. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for resize_area op implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for parity ops header. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected typo in strings and so on with image resize ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resize_area helpers and multithreading. Signed-off-by: shugeo <sgazeos@gmail.com> * Added ResizeArea wrapper * Added test with align_corners and fixed shape processing with only int args given for output size. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test * TF mapping for ResizeArea * Fixed implementation issues with resize_area op for both platforms. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored image resizer struct to use flexible types for ints and floats. Signed-off-by: shugeo <sgazeos@gmail.com> * Improved multithreading with resizeAreaKernel launch. Signed-off-by: shugeo <sgazeos@gmail.com> * Use asynchronical memory copying with cuda platform image resize allocations. Signed-off-by: shugeo <sgazeos@gmail.com> Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
2020-01-22 08:46:33 +01:00
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test7) {
NDArray input = NDArrayFactory::create<int>('c', {2, 3, 3, 1}, {
1, 2, 3, 4, 5, 6, 7, 8, 9,
1, 2, 3, 4, 5, 6, 7, 8, 9
});
NDArray expected = NDArrayFactory::create<float>('c', {2, 6, 6, 1}, {
1.f, 1.f, 1.5f, 2.f, 2.f, 3.f,
1.f, 1.f, 1.5f, 2.f, 2.f, 3.f,
2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f,
4.f, 4.f, 4.5f, 5.f, 5.f, 6.f,
4.f, 4.f, 4.5f, 5.f, 5.f, 6.f,
7.f, 7.f, 7.5f, 8.f, 8.f, 9.f,
1.f, 1.f, 1.5f, 2.f, 2.f, 3.f,
1.f, 1.f, 1.5f, 2.f, 2.f, 3.f,
2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f,
4.f, 4.f, 4.5f, 5.f, 5.f, 6.f,
4.f, 4.f, 4.5f, 5.f, 5.f, 6.f,
7.f, 7.f, 7.5f, 8.f, 8.f, 9.f
});
//input.linspace(1);
// auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.evaluate({&input}, {}, {6, 6}, {true});
Shugeo resize area (#162) * Added implementation for resize_area op. Initial commit. * Added implementation of resize_area op. Initial revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected resizeArea functor call. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_area. Cpu platform helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation for resize_area helpers. The first part revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added a set of tests for resize_area op. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation for resize_area. Initial approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Adding multithreading for resize_area algorithm. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation of resize_area helpers. Shared memory approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resizeAreaKernel with cuda implementation. * Eliminated compilation errors. * ResizeArea helpers for cuda platform. The first working revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test for batched resize_area op testing. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_are for cuda platform and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed multithreading with resize_area op helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright marks with sources. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for resize_area op implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for parity ops header. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected typo in strings and so on with image resize ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resize_area helpers and multithreading. Signed-off-by: shugeo <sgazeos@gmail.com> * Added ResizeArea wrapper * Added test with align_corners and fixed shape processing with only int args given for output size. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test * TF mapping for ResizeArea * Fixed implementation issues with resize_area op for both platforms. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored image resizer struct to use flexible types for ints and floats. Signed-off-by: shugeo <sgazeos@gmail.com> * Improved multithreading with resizeAreaKernel launch. Signed-off-by: shugeo <sgazeos@gmail.com> * Use asynchronical memory copying with cuda platform image resize allocations. Signed-off-by: shugeo <sgazeos@gmail.com> Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
2020-01-22 08:46:33 +01:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 6x6");
// expected.printBuffer("Area Expect for 6x6");
Shugeo resize area (#162) * Added implementation for resize_area op. Initial commit. * Added implementation of resize_area op. Initial revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected resizeArea functor call. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_area. Cpu platform helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation for resize_area helpers. The first part revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added a set of tests for resize_area op. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation for resize_area. Initial approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Adding multithreading for resize_area algorithm. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation of resize_area helpers. Shared memory approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resizeAreaKernel with cuda implementation. * Eliminated compilation errors. * ResizeArea helpers for cuda platform. The first working revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test for batched resize_area op testing. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_are for cuda platform and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed multithreading with resize_area op helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright marks with sources. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for resize_area op implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for parity ops header. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected typo in strings and so on with image resize ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resize_area helpers and multithreading. Signed-off-by: shugeo <sgazeos@gmail.com> * Added ResizeArea wrapper * Added test with align_corners and fixed shape processing with only int args given for output size. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test * TF mapping for ResizeArea * Fixed implementation issues with resize_area op for both platforms. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored image resizer struct to use flexible types for ints and floats. Signed-off-by: shugeo <sgazeos@gmail.com> * Improved multithreading with resizeAreaKernel launch. Signed-off-by: shugeo <sgazeos@gmail.com> * Use asynchronical memory copying with cuda platform image resize allocations. Signed-off-by: shugeo <sgazeos@gmail.com> Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
2020-01-22 08:46:33 +01:00
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test8) {
NDArray input = NDArrayFactory::create<int>('c', {1, 3, 3, 1}, {
1, 2, 3, 4, 5, 6, 7, 8, 9
});
NDArray expected = NDArrayFactory::create<float>('c', {1, 6, 6, 1}, {
1.f, 1.f, 1.5f, 2.f, 2.f, 3.f,
1.f, 1.f, 1.5f, 2.f, 2.f, 3.f,
2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f,
4.f, 4.f, 4.5f, 5.f, 5.f, 6.f,
4.f, 4.f, 4.5f, 5.f, 5.f, 6.f,
7.f, 7.f, 7.5f, 8.f, 8.f, 9.f
});
//input.linspace(1);
// auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.evaluate({&input}, {}, {6, 6}, {true});
Shugeo resize area (#162) * Added implementation for resize_area op. Initial commit. * Added implementation of resize_area op. Initial revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected resizeArea functor call. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_area. Cpu platform helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation for resize_area helpers. The first part revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added a set of tests for resize_area op. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation for resize_area. Initial approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Adding multithreading for resize_area algorithm. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation of resize_area helpers. Shared memory approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resizeAreaKernel with cuda implementation. * Eliminated compilation errors. * ResizeArea helpers for cuda platform. The first working revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test for batched resize_area op testing. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_are for cuda platform and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed multithreading with resize_area op helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright marks with sources. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for resize_area op implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for parity ops header. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected typo in strings and so on with image resize ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resize_area helpers and multithreading. Signed-off-by: shugeo <sgazeos@gmail.com> * Added ResizeArea wrapper * Added test with align_corners and fixed shape processing with only int args given for output size. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test * TF mapping for ResizeArea * Fixed implementation issues with resize_area op for both platforms. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored image resizer struct to use flexible types for ints and floats. Signed-off-by: shugeo <sgazeos@gmail.com> * Improved multithreading with resizeAreaKernel launch. Signed-off-by: shugeo <sgazeos@gmail.com> * Use asynchronical memory copying with cuda platform image resize allocations. Signed-off-by: shugeo <sgazeos@gmail.com> Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
2020-01-22 08:46:33 +01:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 6x6");
// expected.printBuffer("Area Expect for 6x6");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test9) {
NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4}, {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24
});
NDArray expected = NDArrayFactory::create<float>('c', {1, 10, 10, 4}, {
1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333336f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999998f, 9.999997f, 10.999997f, 11.999997f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 15.666671f, 16.666672f, 17.666672f, 18.666672f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 18.333344f, 19.333344f, 20.333345f, 21.333344f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000002f, 22.000000f, 23.000002f, 24.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000
});
//input.linspace(1);
auto size = NDArrayFactory::create<int>({10, 10});
nd4j::ops::resize_area op;
auto results = op.evaluate({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 10x10");
// expected.printBuffer("Area Expect for 6x6");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test10) {
NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4}, {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24
});
NDArray expected = NDArrayFactory::create<float>('c', {1, 10, 10, 4}, {
1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333336f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999998f, 9.999997f, 10.999997f, 11.999997f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 15.666671f, 16.666672f, 17.666672f, 18.666672f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 18.333344f, 19.333344f, 20.333345f, 21.333344f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000002f, 22.000000f, 23.000002f, 24.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000
});
//input.linspace(1);
//auto size = NDArrayFactory::create<int>({10, 10});
nd4j::ops::resize_area op;
auto results = op.evaluate({&input}, {}, {10, 10});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 10x10");
// expected.printBuffer("Area Expect for 6x6");
Shugeo resize area (#162) * Added implementation for resize_area op. Initial commit. * Added implementation of resize_area op. Initial revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected resizeArea functor call. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_area. Cpu platform helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation for resize_area helpers. The first part revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added a set of tests for resize_area op. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation for resize_area. Initial approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Adding multithreading for resize_area algorithm. Signed-off-by: shugeo <sgazeos@gmail.com> * Cuda implementation of resize_area helpers. Shared memory approach. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resizeAreaKernel with cuda implementation. * Eliminated compilation errors. * ResizeArea helpers for cuda platform. The first working revision. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test for batched resize_area op testing. Signed-off-by: shugeo <sgazeos@gmail.com> * Implementation of resize_are for cuda platform and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed multithreading with resize_area op helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright marks with sources. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for resize_area op implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected copyright mark for parity ops header. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected typo in strings and so on with image resize ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored resize_area helpers and multithreading. Signed-off-by: shugeo <sgazeos@gmail.com> * Added ResizeArea wrapper * Added test with align_corners and fixed shape processing with only int args given for output size. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test * TF mapping for ResizeArea * Fixed implementation issues with resize_area op for both platforms. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored image resizer struct to use flexible types for ints and floats. Signed-off-by: shugeo <sgazeos@gmail.com> * Improved multithreading with resizeAreaKernel launch. Signed-off-by: shugeo <sgazeos@gmail.com> * Use asynchronical memory copying with cuda platform image resize allocations. Signed-off-by: shugeo <sgazeos@gmail.com> Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
2020-01-22 08:46:33 +01:00
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test11) {
NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4}, {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24
});
// NDArray expected = NDArrayFactory::create<float>('c', {1, 6, 9, 4}, {
// 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 17.000006, 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, 19.999992, 18.333334, 19.33333
//
// });
//input.linspace(1);
//auto size = NDArrayFactory::create<int>({10, 10});
nd4j::ops::resize_area op;
auto results = op.evaluate({&input}, {}, {6, 9});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 6x9");
// expected.printBuffer("Area Expect for 6x6");
// ASSERT_TRUE(expected.isSameShape(result));
// ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test12) {
NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4}, {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24
});
// NDArray expected = NDArrayFactory::create<float>('c', {1, 6, 9, 4}, {
// 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 17.000006, 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, 19.999992, 18.333334, 19.33333
//
// });
//input.linspace(1);
//auto size = NDArrayFactory::create<int>({10, 10});
nd4j::ops::resize_area op;
auto results = op.evaluate({&input}, {}, {10, 15});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 6x9");
// expected.printBuffer("Area Expect for 6x6");
// ASSERT_TRUE(expected.isSameShape(result));
// ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test13) {
NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4}, {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24
});
// NDArray expected = NDArrayFactory::create<float>('c', {1, 8, 8, 4}, {
// 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 17.000006, 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, 19.999992, 18.333334, 19.33333
//
// });
//input.linspace(1);
//auto size = NDArrayFactory::create<int>({10, 10});
nd4j::ops::resize_area op;
auto results = op.evaluate({&input}, {}, {9, 9});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 8x8");
// expected.printBuffer("Area Expect for 6x6");
// ASSERT_TRUE(expected.isSameShape(result));
// ASSERT_TRUE(expected.equalsTo(result));
Shugeo image resize bicubic (#56) * Added implementation files for image_resize and resize_bicubic ops. * Image resize and image.resize_bicubic ops implementation. Initial revision. * Finished with infrastructure development for image.resize_bilinear op and image_resizo op implementation. * Refactored resize methods. * Added processing for Mitchelcubic algorithm. * Added check for input/output sizes. * Added int and float types for crop_and_resize op. * Refactored crop_and_resize output type check. * Added helper for bicubic interpolation as TF v.1 does. * Added TF v.1 bicubic helper for cuda platform. * Added cached class for bicubic algorithm. * Refactored cuda implementation for crop_and_resize helper to use proper output type. * Added facilities for bicubic interpolation. * Portion bicubic interpolation from TF. * Added tests for resize_bilinear testing. * Working implementation of bicubic interpolation and tests. * Refactored routines with image_resize bicubic op helper. * Refactored code with coding standards. * Refactored cpu helpers for resize_bicubic op. * Refactored bicubic helpers. * Added bicubic resize facilities. * Implementing cuda kernels for bicubic interpolation. Implementation step. * Cuda implementation of resize_bicubic op helper. * Refactor image.resize_bicubic op helpers. * Refactored helpers for resize_bicubic. Added error checking with cuda implementation. * Refactored cuda implementation of resize_bicubic op helper. The first working revision. * Cuda arch implementation for resize_bicubic op helper. Full working single-threaded revision. * Intermediate bicubic interpolation helper for cuda. * Refactored cpu helper for resize_bicubic. * Multithreaded cuda implementation for resize_bicubic. * Fixed merge issues. * Refactored nlp helpers. * Replicated resize_bicubic for 3D also. * Eliminated waste comments of unused code. * Eliminated waste comments with unused code. * Eliminated waste template definitions. * Eliminated waste debug code. * Eliminated waste comments. * Fixed multithreading with helpers. * Fixed test suites for float and double in float point input lists. * Fixed usage of reshape with 3D/4D on resizes. * Final fixes. * Fixed resize_neighbor op problem.
2019-11-20 20:11:04 +01:00
delete results;
}
2019-06-06 14:21:15 +02:00
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {
2019-06-06 14:21:15 +02:00
functions::summarystats::SummaryStatsData<double> var1;
functions::summarystats::SummaryStatsData<double> var2;
var2.n = var2.mean = var2.M2 = var2.M3 = var2.M4 = var2.bias = 5;
2019-06-06 14:21:15 +02:00
functions::summarystats::SummaryStatsData<double>* arr = new functions::summarystats::SummaryStatsData<double>[2];
arr[0] = var1;
arr[1] = var2;
arr[0] = arr[1];
functions::summarystats::SummaryStatsData<double> var3(var1);
ASSERT_TRUE(arr[0].n == arr[0].mean && arr[0].M2 == arr[0].M3 && arr[0].n == 5);
ASSERT_TRUE(arr[1].n == arr[1].mean && arr[1].M2 == arr[1].M3 && arr[1].n == 5);
ASSERT_TRUE(var3.n == var3.mean && var3.M2 == var3.M3 && var3.n == 0);
delete []arr;
}
Shugeo solve linear (#191) * linear equations systems solve op. Initial commit. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed compiling issues. Signed-off-by: shugeo <sgazeos@gmail.com> * Linear equations systems solve. The next stage commit. Signed-off-by: shugeo <sgazeos@gmail.com> * Added test for linear equations systems solve operation. Signed-off-by: shugeo <sgazeos@gmail.com> * Added additional test and fixed lower matrix retrievance. * Implementation for solve of the systems of linear equations." Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored permutation generation. Signed-off-by: shugeo <sgazeos@gmail.com> * Added restore for permutations batched with cuda helper for solve op. Signed-off-by: shugeo <sgazeos@gmail.com> * Finished cuda implementation for solve op helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored cpu helpers for solve op. Signed-off-by: shugeo <sgazeos@gmail.com> * Fix gtest output on Windows * Fixed issue with permutation matrix for cuda implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed issue with permutation matrix for cpu implementation. Signed-off-by: shugeo <sgazeos@gmail.com> * Eliminated waste comments. Signed-off-by: shugeo <sgazeos@gmail.com> * LinearSolve added * Mapping added * Javadoc added * Refactored implementation of triangular_solve helpers and tests for solve matrix equations generally. Signed-off-by: shugeo <sgazeos@gmail.com> * Added a test for solve op. Signed-off-by: shugeo <sgazeos@gmail.com> * Solve test added * Fix for TF import Co-authored-by: Serhii Shepel <9946053+sshepel@users.noreply.github.com> Co-authored-by: raver119 <raver119@gmail.com> Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
2020-02-04 06:59:11 +01:00
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, Solve_Test_1) {
auto a = NDArrayFactory::create<float>('c', {3, 3}, {
2.f, -1.f, -2.f, -4.f, 6.f, 3.f, -4.f, -2.f, 8.f
});
auto b = NDArrayFactory::create<float>('c', {3, 1}, {
2.f, 4.f, 3.f
});
auto exp = NDArrayFactory::create<float>('c', {3, 1}, {
7.625f, 3.25f, 5.f
});
nd4j::ops::solve op;
auto res = op.evaluate({&a, &b});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
// z->printIndexedBuffer("Solve of 3x3");
ASSERT_TRUE(exp.equalsTo(z));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, Solve_Test_2) {
auto a = NDArrayFactory::create<float>('c', {4, 4}, {
1.f, 1.f, 1.f, 1.f,
0.f, 1.f, 1.f, 0.f,
0.f, 0.f, 2.f, 1.f,
0.f, 0.f, 0.f, 3.f,
});
auto b = NDArrayFactory::create<float>('c', {4, 1}, {
2.f, 4.f, 2.f, 4.f
});
auto exp = NDArrayFactory::create<float>('c', {4, 1}, {
-3.3333333f, 3.6666666f, 0.333333f, 1.3333333f
});
nd4j::ops::solve op;
auto res = op.evaluate({&a, &b});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
// z->printIndexedBuffer("Solve 4x4");
ASSERT_TRUE(exp.equalsTo(z));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, Solve_Test_3) {
auto a = NDArrayFactory::create<float>('c', {2, 4, 4}, {
1.f, 1.f, 1.f, 1.f,
0.f, 1.f, 1.f, 0.f,
0.f, 0.f, 2.f, 1.f,
0.f, 0.f, 0.f, 3.f,
3.f, 0.f, 0.f, 0.f,
2.f, 1.f, 0.f, 0.f,
1.f, 0.f, 1.f, 0.f,
1.f, 1.f, 1.f, 1.f
});
auto b = NDArrayFactory::create<float>('c', {2, 4, 1}, {
2.f, 4.f, 2.f, 4.f,
4.f, 2.f, 4.f, 2.f
});
auto exp = NDArrayFactory::create<float>('c', {2, 4, 1}, {
-3.3333333f, 3.6666666f, 0.333333f, 1.3333333f,
1.333333f, -0.6666667f, 2.6666667f, -1.3333333f
});
nd4j::ops::solve op;
auto res = op.evaluate({&a, &b});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
// z->printIndexedBuffer("Solve 4x4");
ASSERT_TRUE(exp.equalsTo(z));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, Solve_Test_4) {
auto a = NDArrayFactory::create<float>('c', {2, 2, 2}, {
0.7788f, 0.8012f, 0.7244f, 0.2309f,
0.7271f, 0.1804f, 0.5056f, 0.8925f
});
auto b = NDArrayFactory::create<float>('c', {2, 2, 2}, {
0.7717f, 0.9281f, 0.9846f, 0.4838f,
0.6433f, 0.6041f, 0.6501f, 0.7612f
});
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {
// 1.524494767f, 0.432706356f,-0.518630624f, 0.737760842f,
// 0.819143713f, 0.720401764f, 0.264349997f, 0.444699198f
1.5245394f, 0.4326952f, -0.51873577f, 0.7377896f,
0.81915987f, 0.72049433f, 0.2643504f, 0.44472617f
});
nd4j::ops::solve op;
auto res = op.evaluate({&a, &b});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
// z->printBuffer("4 Solve 4x4");
// exp.printBuffer("4 Expec 4x4");
ASSERT_TRUE(exp.equalsTo(z));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, Solve_Test_5) {
auto a = NDArrayFactory::create<float>('c', {3, 3}, {
0.7788f, 0.8012f, 0.7244f,
0.2309f, 0.7271f, 0.1804f,
0.5056f, 0.8925f, 0.5461f
});
auto b = NDArrayFactory::create<float>('c', {3, 3}, {
0.7717f, 0.9281f, 0.9846f,
0.4838f, 0.6433f, 0.6041f,
0.6501f, 0.7612f, 0.7605f
});
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
1.5504692f, 1.8953944f, 2.2765768f,
0.03399149f, 0.2883001f, 0.5377323f,
-0.8774802f, -1.2155888f, -1.8049058f
});
nd4j::ops::solve op;
auto res = op.evaluate({&a, &b}, {true});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
auto z = res->at(0);
z->printBuffer("4 Solve 4x4");
exp.printBuffer("4 Expec 4x4");
ASSERT_TRUE(exp.equalsTo(z));
delete res;
}
2019-06-06 14:21:15 +02:00
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52,
-12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04});
NDArray dLdwExp('c', {2,3,4}, {0.9216 , 3.6864 , 8.2944 , 14.7456 , 23.04 , 33.1776 , 45.1584 , 58.9824 , 74.6496 , 92.16 ,111.51361,132.7104 ,
155.75038,180.63359,207.35999,235.9296 ,266.34238,298.59842,332.6976 ,368.64001,406.4256 ,446.05444,487.5264 ,530.84161});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto dLdw = results->at(1);
auto dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
delete results;
}
2019-06-06 14:21:15 +02:00
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test2) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,1,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdwExp('c', {2,1,4}, {98.61121,129.024 , 164.9664 , 206.4384 , 828.51837,925.28644,1027.58398,1135.41113});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdw = results->at(1);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights(nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52,
-12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04});
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray dLdwExp('c', {}, std::vector<double>{4515.84});
2019-06-06 14:21:15 +02:00
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test4) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdwExp('c', {1,3,1}, {807.32153, 1426.63684, 2281.88159});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test5) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-0.08,-0.16,-0.24,-0.32,-0.4 ,-0.48,-0.56,-0.64,-0.72,-0.8 ,-0.88,-0.96,
-1.04,-1.12,-1.2 ,-1.28,-1.36,-1.44,-1.52,-1.6 ,-1.68,-1.76,-1.84,-1.92});
NDArray dLdwExp('c', {2,3,4}, {-15.6032,-15.3728,-14.9888,-14.4512,-13.76 ,-12.9152,-11.9168,-10.7648, -9.4592, -8. , -6.3872, -4.6208,
-2.7008, -0.6272, 1.6 , 3.9808, 6.5152, 9.2032, 12.0448, 15.04 , 18.1888, 21.4912, 24.9472, 28.5568});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test6) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdwExp('c', {1,3,1}, {-58.16319, -6.5536 , 64.71682});
2019-06-06 14:21:15 +02:00
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdw = results->at(1);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights(nd4j::DataType::DOUBLE);
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray dLdwExp('c', {}, std::vector<double>{0.});
2019-06-06 14:21:15 +02:00
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdw = results->at(1);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
2019-06-06 14:21:15 +02:00
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test8) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {0. ,0. ,0. ,0. ,-0.48 ,-0.576,-0.672,-0.768,-0.864,-0.96 ,-1.056,-1.152,
-1.248,-1.344,-1.44 ,-1.536,-1.632,-1.728,-1.824,-1.92 ,-2.016,-2.112,-2.208,-2.304});
NDArray dLdwExp('c', {2,3,4}, {-22.3488 ,-22.07232,-21.61152,-20.9664 ,-20.13696,-19.1232 ,-17.92512,-16.54272,-14.976 ,-13.22496,-11.2896 , -9.16992,
-6.86592, -4.3776 , -1.70496, 1.152 , 4.19328, 7.41888, 10.8288 , 14.42304, 18.2016 , 22.16449, 26.31168, 30.6432 });
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
weights.p(0, 0.);
weights.p(1, 0.);
weights.p(2, 0.);
weights.p(3, 0.);
nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test9) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-0.04,-0.08,-0.12,-0.16,-0.2 ,-0.24,-0.28,-0.32,-0.36,-0.4 ,-0.44,-0.48,
-0.52,-0.56,-0.6 ,-0.64,-0.68,-0.72,-0.76,-0.8 ,-0.84,-0.88,-0.92,-0.96});
NDArray dLdwExp('c', {2,3,4}, {0.0384, 0.1536, 0.3456, 0.6144, 0.96 , 1.3824, 1.8816, 2.4576, 3.1104, 3.84 , 4.6464, 5.5296,
6.4896, 7.5264, 8.64 , 9.8304,11.0976,12.4416,13.8624,15.36 ,16.9344,18.5856,20.3136,22.1184});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test10) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray dLdwExp('c', {1,1}, std::vector<double>{188.16});
2019-06-06 14:21:15 +02:00
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdw = results->at(1);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test11) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
NDArray dLdwExp('c', {1,3,1}, {33.6384 ,59.4432 ,95.07841});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdw = results->at(1);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test12) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {0.,0.,0.,0., -0.24 ,-0.288,-0.336,-0.384,-0.432,-0.48 ,-0.528,-0.576,
-0.624,-0.672,-0.72 ,-0.768,-0.816,-0.864,-0.912,-0.96 ,-1.008,-1.056,-1.104,-1.152});
NDArray dLdwExp('c', {2,3,4}, {0.04608, 0.18432, 0.41472, 0.73728, 1.152 , 1.65888, 2.25792, 2.94912, 3.73248, 4.608 , 5.57568, 6.63552,
7.78752, 9.03168,10.368 ,11.79648,13.31712,14.92992,16.63488,18.432 ,20.32128,22.30272,24.37632,26.54208});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
weights.t<double>(0) = 0.;
weights.t<double>(1) = 0.;
weights.t<double>(2) = 0.;
weights.t<double>(3) = 0.;
2019-06-06 14:21:15 +02:00
nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test13) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,1}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
-1.04,-1.12,-1.2 ,-1.28,-1.36,-1.44,-1.52,-1.6 ,-1.68,-1.76,-1.84,-1.92});
NDArray dLdwExp('c', {2,3,1}, {2.304 , 13.3632 , 34.2528 , 64.97279,105.5232 ,155.90401});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
weights.t<double>(0) = 0.;
weights.t<double>(1) = 0.;
weights.t<double>(2) = 0.;
2019-06-06 14:21:15 +02:00
nd4j::ops::mean_sqerr_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
delete results;
}
TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test1) {
auto x = NDArrayFactory::create<float>('c', {4}, {0, 1, 2, 3});
auto y = NDArrayFactory::create<float>('c',{4}, {3, 2, 1, 0});
auto exp = NDArrayFactory::create<float>('c', {4}, {9, 1,1, 9});
nd4j::ops::squaredsubtract op;
auto result = op.evaluate({&x, &y}, {}, {});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(Status::OK(), result->status());
ASSERT_TRUE(exp.equalsTo(result->at(0)));
delete result;
}
TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test2) {
auto x = NDArrayFactory::create<float>('c', {2, 4}, {0, 1, 2, 3, 0, 1, 2, 3});
auto y = NDArrayFactory::create<float>('c',{4}, {3, 2, 1, 0});
auto exp = NDArrayFactory::create<float>('c', {2, 4}, {9, 1,1, 9, 9, 1, 1, 9});
nd4j::ops::squaredsubtract op;
auto result = op.evaluate({&x, &y}, {}, {});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(Status::OK(), result->status());
ASSERT_TRUE(exp.equalsTo(result->at(0)));
delete result;
}
TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test3) {
auto x = NDArrayFactory::create<float>('c', {2, 4}, {0, 1, 2, 3, 0, 1, 2, 3});
auto y = NDArrayFactory::create<float>('c',{4}, {3, 2, 1, 0});
auto exp = NDArrayFactory::create<float>('c', {2, 4}, {-6, -4, 6, 24, -30, -12, 14, 48});
auto eps = NDArrayFactory::create<float>('c', {2, 4}, {1,2,3,4,5,6,7,8});
nd4j::ops::squaredsubtract_bp op;
auto result = op.evaluate({&x, &y, &eps}, {}, {});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(Status::OK(), result->status());
ASSERT_TRUE(exp.equalsTo(result->at(0)));
delete result;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test1) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,
-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5});
NDArray dLdwExp('c', {2,3,4}, {0.96, 1.92, 2.88, 3.84, 4.8 , 5.76, 6.72, 7.68, 8.64, 9.6 ,10.56,11.52,
12.48,13.44,14.4 ,15.36,16.32,17.28,18.24,19.2 ,20.16,21.12,22.08,23.04});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::absolute_difference_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto dLdw = results->at(1);
auto dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
delete results;
}
2019-06-06 14:21:15 +02:00
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test2) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,1,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdwExp('c', {2,1,4}, {14.4 , 17.28, 20.16, 23.04, 48.96, 51.84, 54.72, 57.6});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::absolute_difference_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdw = results->at(1);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights(nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,
-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5});
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray dLdwExp('c', {}, std::vector<double>{288.});
2019-06-06 14:21:15 +02:00
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::absolute_difference_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test4) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdwExp('c', {1,3,1}, {65.28, 96., 126.72001});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::absolute_difference_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test5) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,
-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167});
NDArray dLdwExp('c', {2,3,4}, {-0.92,-0.84,-0.76,-0.68,-0.6 ,-0.52,-0.44,-0.36,-0.28,-0.2 ,-0.12,-0.04,
0.04, 0.12, 0.2 , 0.28, 0.36, 0.44, 0.52, 0.6 , 0.68, 0.76, 0.84, 0.92});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::absolute_difference_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test6) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdwExp('c', {1,3,1}, {-2.56, 0., 2.56});
2019-06-06 14:21:15 +02:00
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::absolute_difference_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdw = results->at(1);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights(nd4j::DataType::DOUBLE);
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray dLdwExp('c', {}, std::vector<double>{0.});
2019-06-06 14:21:15 +02:00
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::absolute_difference_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdw = results->at(1);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
2019-06-06 14:21:15 +02:00
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test8) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-0. ,-0. ,-0. ,-0. ,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,
-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05});
NDArray dLdwExp('c', {2,3,4}, {-1.296,-1.2 ,-1.104,-1.008,-0.912,-0.816,-0.72 ,-0.624,-0.528,-0.432,-0.336,-0.24 ,
-0.144,-0.048, 0.048, 0.144, 0.24 , 0.336, 0.432, 0.528, 0.624, 0.72 , 0.816, 0.912});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
weights.p(0, 0.);
weights.p(1, 0.);
weights.p(2, 0.);
weights.p(3, 0.);
nd4j::ops::absolute_difference_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test9) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,
-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083});
NDArray dLdwExp('c', {2,3,4}, {0.04, 0.08, 0.12, 0.16, 0.2 , 0.24, 0.28, 0.32,0.36, 0.4 , 0.44, 0.48,
0.52, 0.56, 0.6 , 0.64,0.68, 0.72, 0.76, 0.8 ,0.84, 0.88, 0.92, 0.96});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::absolute_difference_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test10) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray dLdwExp('c', {1,1}, std::vector<double>{12.});
2019-06-06 14:21:15 +02:00
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
nd4j::ops::absolute_difference_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdw = results->at(1);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test11) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
NDArray dLdwExp('c', {1,3,1}, {2.72, 4., 5.28});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
nd4j::ops::absolute_difference_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdw = results->at(1);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test12) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {0., 0., 0., 0., -0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025,
-0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025});
NDArray dLdwExp('c', {2,3,4}, {0.048, 0.096, 0.144, 0.192,0.24 , 0.288, 0.336, 0.384,0.432, 0.48 , 0.528, 0.576,
0.624, 0.672, 0.72 , 0.768,0.816, 0.864, 0.912, 0.96 ,1.008, 1.056, 1.104, 1.152});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
weights.t<double>(0) = 0.;
weights.t<double>(1) = 0.;
weights.t<double>(2) = 0.;
weights.t<double>(3) = 0.;
2019-06-06 14:21:15 +02:00
nd4j::ops::absolute_difference_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test13) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,1}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0.,
-0.04167, -0.04167, -0.04167, -0.04167,-0.04167, -0.04167, -0.04167, -0.04167,-0.04167, -0.04167, -0.04167, -0.04167});
NDArray dLdwExp('c', {2,3,1}, {0.8 ,2.08,3.36,4.64,5.92,7.2 });
predictions.linspace(0.04, 0.04);
labels.linspace(1);
weights.assign(0.5);
weights.t<double>(0) = 0.;
weights.t<double>(1) = 0.;
weights.t<double>(2) = 0.;
2019-06-06 14:21:15 +02:00
nd4j::ops::absolute_difference_loss_grad op;
auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl));
ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, BFloat16_Test_1) {
2019-06-06 14:21:15 +02:00
NDArray x = NDArrayFactory::create<bfloat16>('c', {2,3,4});
NDArray y = NDArrayFactory::create<bfloat16>('c', {2,3,4});//('c', {2,3,4}, nd4j::DataType::BFLOAT16);
NDArray exp = NDArrayFactory::create<bfloat16>('c', {2,3,4});//('c', {2,3,4}, nd4j::DataType::BFLOAT16);
2019-06-06 14:21:15 +02:00
x.linspace(1);
y.linspace(1);
exp.linspace(2,2);
nd4j::ops::add op;
auto results = op.evaluate({&x, &y}, {}, {});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
2019-06-06 14:21:15 +02:00
auto res = results->at(0);
ASSERT_TRUE(res->equalsTo(exp));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, BFloat16_Test_2) {
NDArray x = NDArrayFactory::create<float16>('c', {2,3,4});
NDArray y = NDArrayFactory::create<bfloat16>('c', {2,3,4});//('c', {2,3,4}, nd4j::DataType::BFLOAT16);
NDArray exp = NDArrayFactory::create<float16>('c', {2,3,4});//('c', {2,3,4}, nd4j::DataType::BFLOAT16);
x.linspace(1);
y.linspace(1);
exp.linspace(2,2);
nd4j::ops::add op;
auto results = op.evaluate({&x, &y}, {}, {});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto res = results->at(0);
ASSERT_TRUE(res->equalsTo(exp));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, BFloat16_Test_3) {
NDArray x('c', {2,3,4}, nd4j::DataType::BFLOAT16);
NDArray y('c', {2,3,4}, nd4j::DataType::BFLOAT16);
NDArray exp('c', {2,3,4}, nd4j::DataType::BFLOAT16);
x.linspace(1);
y.linspace(1);
exp.linspace(2,2);
nd4j::ops::add op;
auto results = op.evaluate({&x, &y}, {}, {});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto res = results->at(0);
ASSERT_TRUE(res->equalsTo(exp));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test1) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-0.25999, -0.755 , -1.25 , -1.745 , -2.24001, -2.73502, -3.23004, -3.72508, -4.22014, -4.71523, -5.21034, -5.70548,
-6.20066, -6.69587, -7.19113, -7.68643, -8.18177, -8.67717, -9.17262, -9.66813,-10.1637 ,-10.65932,-11.15501,-11.65077});
NDArray dLdwExp('c', {2,3,4}, {0.73395, 0.75335, 0.69315, 0.55335, 0.33395, 0.03495, -0.34366, -0.80186, -1.33967, -1.95708, -2.65411, -3.43074,
-4.28698, -5.22285, -6.23833, -7.33343, -8.50815, -9.76251,-11.0965 ,-12.51013,-14.00341,-15.57633,-17.2289 ,-18.96113});
NDArray dLdlExp('c', {2,3,4}, {0.04, 0.02,-0. ,-0.02,-0.04,-0.06,-0.08,-0.1 ,-0.12,-0.14,-0.16,-0.18,
-0.2 ,-0.22,-0.24,-0.26,-0.28,-0.3 ,-0.32,-0.34,-0.36,-0.38,-0.4 ,-0.42});
logits.linspace(-0.08, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
delete results;
}
2019-06-06 14:21:15 +02:00
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test2) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,1,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048,
-4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577});
NDArray dLdwExp('c', {2,1,4}, {0.43622, -0.19079, -0.98462, -1.94525,-18.09855,-20.72768,-23.52373,-26.48669});
NDArray dLdlExp('c', {2,3,4}, {0.028, 0.014, -0. , -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126,
-0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294});
logits.linspace(-0.08, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {0});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights(nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048,
-4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577});
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray dLdwExp('c', {}, std::vector<double>{-91.52109});
2019-06-06 14:21:15 +02:00
NDArray dLdlExp('c', {2,3,4}, {0.028, 0.014, -0., -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126,
-0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294});
logits.linspace(-0.08, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {1});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test4) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdwExp('c', {1,3,1}, {-12.54779,-28.13393,-50.83936});
2019-06-06 14:21:15 +02:00
logits.linspace(-0.08, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {1});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test5) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-0.01542,-0.04417,-0.07292,-0.10167,-0.13042,-0.15917,-0.18792,-0.21667,-0.24543,-0.27419,-0.30294,-0.33171,
-0.36047,-0.38924,-0.41801,-0.44679,-0.47556,-0.50435,-0.53314,-0.56193,-0.59072,-0.61953,-0.64833,-0.67715});
NDArray dLdwExp('c', {2,3,4}, {0.37794, 0.37906, 0.37554, 0.36739, 0.35461, 0.33719, 0.31514, 0.28846, 0.25714, 0.22119, 0.18061, 0.13539,
0.08553, 0.03104,-0.02808,-0.09184,-0.16023,-0.23326,-0.31093,-0.39323,-0.48017,-0.57175,-0.66796,-0.76881});
NDArray dLdlExp('c', {2,3,4}, {0.00233, 0.00117,-0.,-0.00117,-0.00233,-0.0035 ,-0.00467,-0.00583,-0.007 ,-0.00817,-0.00933,-0.0105,
-0.01167,-0.01283,-0.014 ,-0.01517,-0.01633,-0.0175 ,-0.01867,-0.01983,-0.021 ,-0.02217,-0.02333,-0.0245});
logits.linspace(-0.08, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test6) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdwExp('c', {1,3,1}, {1.4966 , 0.19776,-1.69436});
logits.linspace(-0.08, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdw = results->at(1);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights(nd4j::DataType::DOUBLE);
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray dLdwExp('c', {}, std::vector<double>{0.});
2019-06-06 14:21:15 +02:00
logits.linspace(-0.08, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdw = results->at(1);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
2019-06-06 14:21:15 +02:00
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test8) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, { 0. , 0. , 0. , 0. ,-0.1565 ,-0.191 ,-0.2255 ,-0.26001,-0.29451,-0.32902,-0.36353,-0.39805,
-0.43257,-0.46709,-0.50161,-0.53614,-0.57068,-0.60522,-0.63976,-0.67431,-0.70887,-0.74343,-0.778 ,-0.81258});
NDArray dLdwExp('c', {2,3,4}, {0.54353, 0.54487, 0.54065, 0.53087, 0.51553, 0.49463, 0.46817, 0.43615, 0.39857, 0.35543, 0.30672, 0.25246,
0.19264, 0.12725, 0.0563 ,-0.02021,-0.10228,-0.18992,-0.28312,-0.38188,-0.48621,-0.5961 ,-0.71156,-0.83258});
NDArray dLdlExp('c', {2,3,4}, {-0. ,-0. , 0. , 0. ,-0.0028,-0.0042,-0.0056,-0.007 ,-0.0084,-0.0098,-0.0112,-0.0126,
-0.014 ,-0.0154,-0.0168,-0.0182,-0.0196,-0.021 ,-0.0224,-0.0238,-0.0252,-0.0266,-0.028 ,-0.0294});
logits.linspace(-0.08, 0.04);
labels.linspace(1);
weights.assign(0.5);
weights.p(0, 0.);
weights.p(1, 0.);
weights.p(2, 0.);
weights.p(3, 0.);
nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test9) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-0.00771, -0.02208, -0.03646, -0.05083,-0.06521, -0.07958, -0.09396, -0.10834,-0.12271, -0.13709, -0.15147, -0.16585,
-0.18024, -0.19462, -0.20901, -0.22339,-0.23778, -0.25217, -0.26657, -0.28096,-0.29536, -0.30976, -0.32417, -0.33857});
NDArray dLdwExp('c', {2,3,4}, {0.03008, 0.03064, 0.02888, 0.02481, 0.01841, 0.00971, -0.00132, -0.01466,-0.03032, -0.0483 , -0.06859, -0.0912 ,
-0.11612, -0.14337, -0.17293, -0.20481,-0.23901, -0.27552, -0.31435, -0.35551,-0.39898, -0.44476, -0.49287, -0.5433 });
NDArray dLdlExp('c', {2,3,4}, {0.00117, 0.00058, -0. , -0.00058,-0.00117, -0.00175, -0.00233, -0.00292,-0.0035 , -0.00408, -0.00467, -0.00525,
-0.00583, -0.00642, -0.007 , -0.00758,-0.00817, -0.00875, -0.00933, -0.00992,-0.0105 , -0.01108, -0.01167, -0.01225});
logits.linspace(-0.08, 0.04);
labels.linspace(1);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
delete results;
}
2019-06-06 14:21:15 +02:00
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test10) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray dLdwExp('c', {1,1}, std::vector<double>{-3.81338});
2019-06-06 14:21:15 +02:00
logits.linspace(-0.08, 0.04);
labels.linspace(1);
weights.assign(0.5);
nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdw = results->at(1);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test11) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,3,1}, nd4j::DataType::DOUBLE);
NDArray dLdwExp('c', {1,3,1}, {-0.52282,-1.17225,-2.11831});
logits.linspace(-0.08, 0.04);
labels.linspace(1);
weights.assign(0.5);
nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdw = results->at(1);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test12) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. ,-0.07825, -0.0955 , -0.11275, -0.13 ,-0.14726, -0.16451, -0.18177, -0.19902,
-0.21628, -0.23354, -0.25081, -0.26807,-0.28534, -0.30261, -0.31988, -0.33716,-0.35443, -0.37172, -0.389 , -0.40629});
NDArray dLdwExp('c', {2,3,4}, {0.0361 , 0.03677, 0.03466, 0.02977, 0.0221 , 0.01165, -0.00158, -0.01759,-0.03638, -0.05795, -0.08231, -0.10944,
-0.13935, -0.17204, -0.20752, -0.24577,-0.28681, -0.33063, -0.37723, -0.42661,-0.47877, -0.53372, -0.59144, -0.65196});
NDArray dLdlExp('c', {2,3,4}, {-0. , -0. , 0. , 0. ,-0.0014, -0.0021, -0.0028, -0.0035,-0.0042, -0.0049, -0.0056, -0.0063,
-0.007 , -0.0077, -0.0084, -0.0091,-0.0098, -0.0105, -0.0112, -0.0119,-0.0126, -0.0133, -0.014 , -0.0147});
logits.linspace(-0.08, 0.04);
labels.linspace(1);
weights.assign(0.5);
weights.t<double>(0) = 0.;
weights.t<double>(1) = 0.;
weights.t<double>(2) = 0.;
weights.t<double>(3) = 0.;
2019-06-06 14:21:15 +02:00
nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test13) {
2019-06-06 14:21:15 +02:00
NDArray labels('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2,3,1}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
-0.36047, -0.38924, -0.41801, -0.44679,-0.47556, -0.50435, -0.53314, -0.56193,-0.59072, -0.61953, -0.64833, -0.67715});
NDArray dLdwExp('c', {2,3,1}, {0.22882, 0.02428,-0.4768 ,-1.27447,-2.36878,-3.75981,});
NDArray dLdlExp('c', {2,3,4}, {-0. , -0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.,
-0.01167, -0.01283, -0.014 , -0.01517,-0.01633, -0.0175 , -0.01867, -0.01983,-0.021 , -0.02217, -0.02333, -0.0245});
logits.linspace(-0.08, 0.04);
labels.linspace(1);
weights.assign(0.5);
weights.t<double>(0) = 0.;
weights.t<double>(1) = 0.;
weights.t<double>(2) = 0.;
2019-06-06 14:21:15 +02:00
nd4j::ops::sigm_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
ASSERT_TRUE(dLdlExp.isSameShape(dLdl));
ASSERT_TRUE(dLdlExp.equalsTo(dLdl));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, BFloat16_Test_4) {
NDArray x = NDArrayFactory::create<float>('c', {2,3,4});
NDArray y = NDArrayFactory::create<bfloat16>('c', {2,3,4});//('c', {2,3,4}, nd4j::DataType::BFLOAT16);
NDArray exp = NDArrayFactory::create<float>('c', {2,3,4});//('c', {2,3,4}, nd4j::DataType::BFLOAT16);
x.linspace(1);
y.linspace(1);
exp.linspace(2,2);
nd4j::ops::add op;
auto results = op.evaluate({&x, &y}, {}, {});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto res = results->at(0);
ASSERT_TRUE(res->equalsTo(exp));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, BFloat16_Test_5) {
NDArray x = NDArrayFactory::create<float>('c', {2,3,4});
NDArray y = NDArrayFactory::create<bfloat16>('c', {2,3,4});//('c', {2,3,4}, nd4j::DataType::BFLOAT16);
NDArray exp = NDArrayFactory::create<float>('c', {2,3,4});//('c', {2,3,4}, nd4j::DataType::BFLOAT16);
x.linspace(2, 2);
y.linspace(1);
exp.linspace(1);
nd4j::ops::subtract op;
auto results = op.evaluate({&x, &y}, {}, {});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto res = results->at(0);
ASSERT_TRUE(res->equalsTo(exp));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, BFloat16_Test_6) {
NDArray x = NDArrayFactory::create<bfloat16>('c', {2,3,4});
NDArray y = NDArrayFactory::create<double>('c', {2,3,4});//('c', {2,3,4}, nd4j::DataType::BFLOAT16);
NDArray exp = NDArrayFactory::create<bfloat16>('c', {2,3,4});//('c', {2,3,4}, nd4j::DataType::BFLOAT16);
x.linspace(2, 2);
y.linspace(1);
exp.linspace(1);
nd4j::ops::subtract op;
auto results = op.evaluate({&x, &y}, {}, {});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto res = results->at(0);
ASSERT_TRUE(res->equalsTo(exp));
delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test1) {
NDArray labels('c', {2,4}, {0,0,1,0, 0,1,0,0}, nd4j::DataType::INT32);
NDArray logits('c', {2,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,4}, {0.1176, 0.1224, -0.3726, 0.1326, 0.1176, -0.3776, 0.1274, 0.1326});
NDArray dLdwExp('c', {2}, {1.36729, 1.40729});
logits.linspace(-0.08, 0.04);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::softmax_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
2019-06-06 14:21:15 +02:00
delete results;
}
2019-06-06 14:21:15 +02:00
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test2) {
NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125});
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray dLdwExp('c', {1}, std::vector<double>{1.38629});
2019-06-06 14:21:15 +02:00
logits = 2.;
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::softmax_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
2019-06-06 14:21:15 +02:00
delete results;
}
2019-06-06 14:21:15 +02:00
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test3) {
NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray weights('c', {}, std::vector<double>{0}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125});
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray dLdwExp('c', {}, std::vector<double>{1.38629});
2019-06-06 14:21:15 +02:00
logits = 2.;
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::softmax_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
2019-06-06 14:21:15 +02:00
delete results;
}
2019-06-06 14:21:15 +02:00
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test4) {
NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray weights('c', {}, std::vector<double>{0}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {4}, {0.23521, 0.2448 , -0.7452 , 0.26519});
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray dLdwExp('c', {}, std::vector<double>{0.});
2019-06-06 14:21:15 +02:00
logits.linspace(-0.08, 0.04);
weights = 0.5;
2019-06-06 14:21:15 +02:00
nd4j::ops::softmax_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
2019-06-06 14:21:15 +02:00
delete results;
}
2019-06-06 14:21:15 +02:00
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test5) {
NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {4}, {0.1176, 0.1224, -0.3726, 0.1326});
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray dLdwExp('c', {1}, std::vector<double>{1.36729});
2019-06-06 14:21:15 +02:00
logits.linspace(-0.08, 0.04);
weights = 0.5;
2019-06-06 14:21:15 +02:00
nd4j::ops::softmax_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
2019-06-06 14:21:15 +02:00
delete results;
}
2019-06-06 14:21:15 +02:00
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test6) {
NDArray labels('c', {2,4}, {0,0,1,0, 0,1,0,0}, nd4j::DataType::INT32);
NDArray logits('c', {2,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,4}, {0.0801, 0.0849, -0.2601, 0.0951, 0.0801, -0.2651, 0.0899, 0.0951});
NDArray dLdwExp('c', {2}, {-0.014000, 0.014000});
logits.linspace(-0.08, 0.04);
weights.assign(0.5);
nd4j::ops::softmax_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test7) {
NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1, 1,0,0,0, 0,1,0,0}, nd4j::DataType::INT32);
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,3}, {0.5, 0., 1.5});
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-0.0956 , 0.0306 , 0.03185, 0.03315, 0.,-0., 0., 0., 0.0882 , 0.0918 ,-0.27945, 0.09945,
0.0294 , 0.0306 , 0.03185,-0.09185,-0., 0., 0., 0., 0.0882 ,-0.2832 , 0.09555, 0.09945});
NDArray dLdwExp('c', {1,3}, {0.69365, 0.71365, 0.69365});
logits.linspace(-0.08, 0.04);
2019-06-06 14:21:15 +02:00
nd4j::ops::softmax_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
2019-06-06 14:21:15 +02:00
delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test8) {
NDArray labels('c', {2,3,4,5}, {1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,
0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,
0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0}, nd4j::DataType::INT32);
2019-06-06 14:21:15 +02:00
NDArray logits('c', {2,3,4,5}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,1,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4,5}, {-0.03399, 0.00799, 0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335,
0.00866, 0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399,
0.00799, 0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866,
0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799,
0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901,
0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799, 0.00832,
0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901, 0.00768,
0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799, 0.00832, 0.00866,
2019-06-06 14:21:15 +02:00
0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901});
NDArray dLdwExp('c', {1,1,4}, {0.005, 0.00167, -0.00167, -0.005});
logits.linspace(-0.08, 0.04);
weights.assign(0.5);
2019-06-06 14:21:15 +02:00
nd4j::ops::softmax_cross_entropy_loss_grad op;
auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
auto *dLdw = results->at(1);
auto *dLdl = results->at(2);
// dLdp->printIndexedBuffer();
2019-06-06 14:21:15 +02:00
// ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
// ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
ASSERT_TRUE(dLdwExp.isSameShape(dLdw));
ASSERT_TRUE(dLdwExp.equalsTo(dLdw));
2019-06-06 14:21:15 +02:00
delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, SafeDivideMixed_Test1) {
NDArray labels('c', {2, 3}, {1.0, 2.0, 3.0, -1.0, 2.0, 1.0});
Shyrma temp (#131) * - specifying template instantiation for certain types in float16 and bloat16 Signed-off-by: Yurii <iuriish@yahoo.com> * - polishing bfloat16 and float16 member functions template specialization Signed-off-by: Yurii <iuriish@yahoo.com> * - rewrite and overload array +-*/ scalar and scalar +-*/ arr in NDAray class Signed-off-by: Yurii <iuriish@yahoo.com> * - make corrections which have to do with and rvalue lvalue conversions Signed-off-by: Yurii <iuriish@yahoo.com> * - provide move semantic in NDArray operators array +-/* array Signed-off-by: Yurii <iuriish@yahoo.com> * float16/bfloat16 tweaks Signed-off-by: raver119 <raver119@gmail.com> * one more tweak Signed-off-by: raver119 <raver119@gmail.com> * - make float16 and bfloat16 to compile successfully on cuda Signed-off-by: Yurii <iuriish@yahoo.com> * - do not use resources of view-like arrays when move semantics is applied Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of pointers in signatures NDArray methods 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - correction of signature of NDArray::dup method Signed-off-by: Yurii <iuriish@yahoo.com> * - correction of signature of NDArray::reduceAlongDimension method Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyIndexReduce and applyTrueBroadcast methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyReduce3 and varianceAlongDimension methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::tensorsAlongDimension and diagonal methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::allTensorsAlongDimension Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::reduceAlongDimension 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyTransform 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyPairwiseTransform 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyBroadcast 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyTrueBroadcast 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyScalar and applyScalarArr Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::lambda methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::reduce3 methods 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of following NDArray methods: add/sub/mul/div row/column and fillAsTriangular Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::tileToShape methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::isShapeSameStrict method Signed-off-by: Yurii <iuriish@yahoo.com> * minor corrections in tests Signed-off-by: Yurii <iuriish@yahoo.com> * - replace reduce op in batchnorm mkldnn Signed-off-by: Yurii <iuriish@yahoo.com> * - add explicit templates instantiations for operator+(NDArray&&. const scalar) Signed-off-by: Yurii <iuriish@yahoo.com> * - corrections of casts in float16/bfloat16 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide move semantics in following NDArray methods: transform, applyTrueBroadcast, transpose, reshape, permute Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of input array A duplicate in svd cuda op Signed-off-by: Yurii <iuriish@yahoo.com> * - avoid available bug in svd cuda API Signed-off-by: Yurii <iuriish@yahoo.com> * - add temporary global memory buffer in svd cuda when calcUV = false and m != n Signed-off-by: Yurii <iuriish@yahoo.com> * - remove test with blfoat16 type for betainC Signed-off-by: Yurii <iuriish@yahoo.com> * - resolve conflicts after master has been merged in Signed-off-by: Yurii <iuriish@yahoo.com> * - changed type of affected input array in fused_batch_norm Signed-off-by: Yurii <iuriish@yahoo.com> * - add several explicit type castings Signed-off-by: Yurii <iuriish@yahoo.com> * - add ND4J_EXPORT to operators Signed-off-by: Yurii <iuriish@yahoo.com> * - add explicit template types in instantiations of template arithm operators of NDArray class Signed-off-by: Yurii <iuriish@yahoo.com> * - one more test fix Signed-off-by: Yurii <iuriish@yahoo.com> Co-authored-by: raver119 <raver119@gmail.com>
2019-12-20 20:35:39 +01:00
auto sumDiff = labels.reduceAlongDimension(reduce::Sum, {1}, true);
2019-06-06 14:21:15 +02:00
NDArray numOfNonZero(sumDiff.getShapeInfo(), nd4j::DataType::INT64, false);
numOfNonZero.assign(1);
Shyrma temp (#131) * - specifying template instantiation for certain types in float16 and bloat16 Signed-off-by: Yurii <iuriish@yahoo.com> * - polishing bfloat16 and float16 member functions template specialization Signed-off-by: Yurii <iuriish@yahoo.com> * - rewrite and overload array +-*/ scalar and scalar +-*/ arr in NDAray class Signed-off-by: Yurii <iuriish@yahoo.com> * - make corrections which have to do with and rvalue lvalue conversions Signed-off-by: Yurii <iuriish@yahoo.com> * - provide move semantic in NDArray operators array +-/* array Signed-off-by: Yurii <iuriish@yahoo.com> * float16/bfloat16 tweaks Signed-off-by: raver119 <raver119@gmail.com> * one more tweak Signed-off-by: raver119 <raver119@gmail.com> * - make float16 and bfloat16 to compile successfully on cuda Signed-off-by: Yurii <iuriish@yahoo.com> * - do not use resources of view-like arrays when move semantics is applied Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of pointers in signatures NDArray methods 1 Signed-off-by: Yurii <iuriish@yahoo.com> * - correction of signature of NDArray::dup method Signed-off-by: Yurii <iuriish@yahoo.com> * - correction of signature of NDArray::reduceAlongDimension method Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyIndexReduce and applyTrueBroadcast methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyReduce3 and varianceAlongDimension methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::tensorsAlongDimension and diagonal methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::allTensorsAlongDimension Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::reduceAlongDimension 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyTransform 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyPairwiseTransform 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyBroadcast 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyTrueBroadcast 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::applyScalar and applyScalarArr Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::lambda methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::reduce3 methods 2 Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of following NDArray methods: add/sub/mul/div row/column and fillAsTriangular Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::tileToShape methods Signed-off-by: Yurii <iuriish@yahoo.com> * - signature correction of NDArray::isShapeSameStrict method Signed-off-by: Yurii <iuriish@yahoo.com> * minor corrections in tests Signed-off-by: Yurii <iuriish@yahoo.com> * - replace reduce op in batchnorm mkldnn Signed-off-by: Yurii <iuriish@yahoo.com> * - add explicit templates instantiations for operator+(NDArray&&. const scalar) Signed-off-by: Yurii <iuriish@yahoo.com> * - corrections of casts in float16/bfloat16 Signed-off-by: Yurii <iuriish@yahoo.com> * - provide move semantics in following NDArray methods: transform, applyTrueBroadcast, transpose, reshape, permute Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of input array A duplicate in svd cuda op Signed-off-by: Yurii <iuriish@yahoo.com> * - avoid available bug in svd cuda API Signed-off-by: Yurii <iuriish@yahoo.com> * - add temporary global memory buffer in svd cuda when calcUV = false and m != n Signed-off-by: Yurii <iuriish@yahoo.com> * - remove test with blfoat16 type for betainC Signed-off-by: Yurii <iuriish@yahoo.com> * - resolve conflicts after master has been merged in Signed-off-by: Yurii <iuriish@yahoo.com> * - changed type of affected input array in fused_batch_norm Signed-off-by: Yurii <iuriish@yahoo.com> * - add several explicit type castings Signed-off-by: Yurii <iuriish@yahoo.com> * - add ND4J_EXPORT to operators Signed-off-by: Yurii <iuriish@yahoo.com> * - add explicit template types in instantiations of template arithm operators of NDArray class Signed-off-by: Yurii <iuriish@yahoo.com> * - one more test fix Signed-off-by: Yurii <iuriish@yahoo.com> Co-authored-by: raver119 <raver119@gmail.com>
2019-12-20 20:35:39 +01:00
sumDiff.applyPairwiseTransform(pairwise::SafeDivide, numOfNonZero, sumDiff);
2019-06-06 14:21:15 +02:00
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test1) {
NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1, 1,0,0,0, 0,1,0,0});
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-0.76479, 0.2448, 0.2548, 0.26519, 0.23521,-0.7552, 0.2548, 0.26519, 0.23521, 0.2448,-0.7452, 0.26519,
0.23521, 0.2448, 0.2548,-0.73481,-0.76479, 0.2448, 0.2548, 0.26519, 0.23521,-0.7552, 0.2548, 0.26519});
logits.linspace(-0.08, 0.04);
2019-06-06 14:21:15 +02:00
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.evaluate({&logits, &labels}, {}, {});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test2) {
NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,1, 0,0,1,0, 0,0,0,1, 1,0,1,0, 0,1,0,0});
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-0.71836, 0.28164, 0.28164, 0.28164, 0.33051, -0.66949, 0.33051, -0.66949, 0.38785, 0.38785, -0.61215, 0.38785,
0.28164, 0.28164, 0.28164, -0.71836,-0.66949, 0.33051, -0.66949, 0.33051, 0.38785, -0.61215, 0.38785, 0.38785});
logits.linspace(-0.08, 0.04);
2019-06-06 14:21:15 +02:00
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.evaluate({&logits, &labels}, {}, {1});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test3) {
NDArray labels('c', {2,3}, {1,0,0, 0,1,1});
NDArray logits('c', {2,3}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3}, {-0.52996, 0.47004, 0.47004, 0.52996, -0.47004, -0.47004});
logits.linspace(-0.08, 0.04);
2019-06-06 14:21:15 +02:00
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.evaluate({&logits, &labels}, {}, {0});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test4) {
NDArray labels('c', {2,1}, {1,1});
NDArray logits('c', {2,1}, {-0.04, 0.04});
NDArray dLdpExp('c', {2,1}, {0., 0.});
2019-06-06 14:21:15 +02:00
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.evaluate({&logits, &labels}, {}, {1});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test5) {
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray labels('c', {2,1}, std::vector<double>{1,0});
2019-06-06 14:21:15 +02:00
NDArray logits('c', {2,1}, {-0.04, 0.04});
NDArray dLdpExp('c', {2,1}, {-0.51999, 0.51999});
2019-06-06 14:21:15 +02:00
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.evaluate({&logits, &labels}, {}, {0});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test6) {
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray labels('c', {1,2}, {1,1.});
2019-06-06 14:21:15 +02:00
NDArray logits('c', {1,2}, {-0.04, 0.04});
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray dLdpExp('c', {1,2}, {0, 0.});
2019-06-06 14:21:15 +02:00
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.evaluate({&logits, &labels}, {}, {0});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test7) {
NDArray labels('c', {2}, {0,1});
NDArray logits('c', {2}, {-0.04, 0.04});
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2}, {0.48001, -0.48001});
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.evaluate({&logits, &labels}, {}, {0});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test8) {
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray labels('c', {1}, std::vector<double>{1});
NDArray logits('c', {1}, std::vector<double>{0.04});
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray dLdpExp('c', {1}, std::vector<double>{0});
2019-06-06 14:21:15 +02:00
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.evaluate({&logits, &labels}, {}, {0});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, Multiply_BP_Test1) {
NDArray x('c', {3,4,5}, nd4j::DataType::DOUBLE);
NDArray y('c', {1,1,1}, nd4j::DataType::DOUBLE);
NDArray dLdp('c', {3,4,5}, nd4j::DataType::DOUBLE);
NDArray dLdpExp('c', {3,4,5}, nd4j::DataType::DOUBLE);
x.assign(1.0);//linspace(0.1, 0.1);
y.assign(1.0);
dLdp.assign(1.0);
dLdpExp.assign(1.0);
nd4j::ops::multiply_bp op;
auto results = op.evaluate({&x, &y, &dLdp}, {}, {});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdo = results->at(0);
ASSERT_TRUE(dLdpExp.isSameShape(dLdo));
ASSERT_TRUE(dLdpExp.equalsTo(dLdo));
delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test1) {
NDArray labels('c', {2}, {2,1}, nd4j::DataType::INT64);
NDArray logits('c', {2,3}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3}, {0.30061, 0.33222, -0.63283, 0.30061, -0.66778, 0.36717});
logits.linspace(0.1, 0.1);
2019-06-06 14:21:15 +02:00
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.evaluate({&labels, &logits}, {}, {});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test2) {
NDArray labels('c', {2}, {0,1}, nd4j::DataType::INT64);
NDArray logits('c', {2,3}, nd4j::DataType::DOUBLE);
NDArray dLdpExp('c', {2,3}, {-0.69939, 0.33222, 0.36717, 0.30061, -0.66778, 0.36717});
2019-06-06 14:21:15 +02:00
logits.linspace(-0.1, 0.1);
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.evaluate({&labels, &logits}, {}, {});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) {
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray labels('c', {}, std::vector<double>{1}, nd4j::DataType::INT64);
2019-06-06 14:21:15 +02:00
NDArray logits('c', {2}, {-0.2, 0.3});
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2}, {0.37754, -0.37754});
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.evaluate({&labels, &logits}, {}, {});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test4) {
NDArray labels('c', {2,3}, {0,1,1, 3,3,2}, nd4j::DataType::INT64);
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
2019-06-06 14:21:15 +02:00
NDArray dLdpExp('c', {2,3,4}, {-0.78616, 0.23633, 0.26118, 0.28865, 0.21384, -0.76367, 0.26118, 0.28865, 0.21384, -0.76367, 0.26118, 0.28865,
0.21384, 0.23633, 0.26118, -0.71135, 0.21384, 0.23633, 0.26118, -0.71135, 0.21384, 0.23633, -0.73882, 0.28865});
logits.linspace(-0.5, 0.1);
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.evaluate({&labels, &logits}, {}, {});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
delete results;
}
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test5) {
Oleh convert (#200) * StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-31 14:30:49 +01:00
NDArray labels('c', {1,1}, std::vector<double>({0}), nd4j::DataType::INT64);
2019-06-06 14:21:15 +02:00
NDArray logits('c', {1,1,2}, {-0.3,0.2});
NDArray dLdpExp('c', {1,1,2}, {-0.62246, 0.62246});
2019-06-06 14:21:15 +02:00
nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op;
auto results = op.evaluate({&labels, &logits}, {}, {});
2019-06-06 14:21:15 +02:00
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *dLdp = results->at(0);
2019-06-06 14:21:15 +02:00
ASSERT_TRUE(dLdpExp.isSameShape(dLdp));
ASSERT_TRUE(dLdpExp.equalsTo(dLdp));
delete results;
}