2019-06-06 14:21:15 +02:00
/*******************************************************************************
* Copyright ( c ) 2015 - 2018 Skymind , Inc .
*
* This program and the accompanying materials are made available under the
* terms of the Apache License , Version 2.0 which is available at
* https : //www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing , software
* distributed under the License is distributed on an " AS IS " BASIS , WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND , either express or implied . See the
* License for the specific language governing permissions and limitations
* under the License .
*
* SPDX - License - Identifier : Apache - 2.0
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
//
// @author Yurii Shyrma (iuriish@yahoo.com), created on 22.06.2018
//
# include "testlayers.h"
# include <ops/declarable/CustomOperations.h>
# include <NDArray.h>
# include <ops/ops.h>
# include <GradCheck.h>
# include <loops/random.h>
using namespace nd4j ;
// Test fixture used by the TEST_F cases in this file.
class DeclarableOpsTests9 : public testing : : Test {
public :
// On construction, writes the string literal below to stdout and flushes the stream.
DeclarableOpsTests9 ( ) {
printf ( " \n " ) ;
fflush ( stdout ) ;
}
} ;
////////////////////////////////////////////////////////////////////////////////
// Runs reduce_stdev_bp twice on a 3x4 input filled by linspace(1): once with a gradient of
// shape {3} and once with a gradient of shape {3,1}. Both runs must yield the same 3x4 result.
TEST_F(DeclarableOpsTests9, reduceStDevBP_test3) {

    auto input        = NDArrayFactory::create<double>('c', {3, 4});
    auto gradColumn   = NDArrayFactory::create<double>('c', {3, 1}, {1., 2., 3.});
    auto gradVector   = NDArrayFactory::create<double>('c', {3}, {1., 2., 3.});
    auto expectedGrad = NDArrayFactory::create<double>('c', {3, 4},
        {-0.335410, -0.111803, 0.111803, 0.335410,
         -0.670820, -0.223607, 0.223607, 0.670820,
         -1.006231, -0.335410, 0.335410, 1.006231});
    input.linspace(1);

    nd4j::ops::reduce_stdev_bp stdevBp;

    // First run: gradient of shape {3}, second argument list {0, 0}.
    auto firstRun = stdevBp.execute({&input, &gradVector}, {0, 0}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, firstRun->status());

    auto firstGrad = firstRun->at(0);
    ASSERT_TRUE(expectedGrad.isSameShape(firstGrad));
    ASSERT_TRUE(expectedGrad.equalsTo(firstGrad));
    delete firstRun;

    // Second run: gradient of shape {3,1}, second argument list {1, 0}.
    auto secondRun = stdevBp.execute({&input, &gradColumn}, {1, 0}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, secondRun->status());

    auto secondGrad = secondRun->at(0);
    ASSERT_TRUE(expectedGrad.isSameShape(secondGrad));
    ASSERT_TRUE(expectedGrad.equalsTo(secondGrad));
    delete secondRun;
}
////////////////////////////////////////////////////////////////////////////////
// Same data as reduceStDevBP_test3, but the first run passes a one-element int array holding 1
// as a third input and two boolean arguments instead of numeric argument lists.
TEST_F(DeclarableOpsTests9, reduceStDevBP_test03) {

    auto input        = NDArrayFactory::create<double>('c', {3, 4});
    auto gradColumn   = NDArrayFactory::create<double>('c', {3, 1}, {1., 2., 3.});
    auto gradVector   = NDArrayFactory::create<double>('c', {3}, {1., 2., 3.});
    auto expectedGrad = NDArrayFactory::create<double>('c', {3, 4},
        {-0.335410, -0.111803, 0.111803, 0.335410,
         -0.670820, -0.223607, 0.223607, 0.670820,
         -1.006231, -0.335410, 0.335410, 1.006231});
    auto axisInput    = NDArrayFactory::create<int>('c', {1}, {1});
    input.linspace(1);

    nd4j::ops::reduce_stdev_bp stdevBp;

    // First run: the axis array travels as an input, flags are {false, false}.
    auto firstRun = stdevBp.execute({&input, &gradVector, &axisInput}, {}, {}, {false, false});
    ASSERT_EQ(ND4J_STATUS_OK, firstRun->status());

    auto firstGrad = firstRun->at(0);
    ASSERT_TRUE(expectedGrad.isSameShape(firstGrad));
    ASSERT_TRUE(expectedGrad.equalsTo(firstGrad));
    delete firstRun;

    // Second run: gradient of shape {3,1} with numeric argument lists {1, 0} and {1}.
    auto secondRun = stdevBp.execute({&input, &gradColumn}, {1, 0}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, secondRun->status());

    auto secondGrad = secondRun->at(0);
    ASSERT_TRUE(expectedGrad.isSameShape(secondGrad));
    ASSERT_TRUE(expectedGrad.equalsTo(secondGrad));
    delete secondRun;
}
/*
////////////////////////////////////////////////////////////////////////////////
TEST_F ( DeclarableOpsTests9 , exponentialDistributionInv_test1 ) {
2019-06-15 13:34:34 +02:00
2019-06-06 14:21:15 +02:00
const int N = 50000 ;
const double lambda = 2. ;
2019-06-15 13:34:34 +02:00
const double mean = 1. / lambda ;
2019-06-06 14:21:15 +02:00
const double std = mean ;
auto x = NDArrayFactory : : create < double > ( ' c ' , { N } ) ;
double extraParams [ ] = { lambda } ;
Nd4jLong * buffer = new Nd4jLong [ N ] ;
NativeOps nativeOps ;
auto rng = ( nd4j : : random : : RandomBuffer * ) nativeOps . initRandom ( nullptr , 123 , N , ( Nd4jPointer ) buffer ) ;
if ( rng = = nullptr )
throw std : : runtime_error ( " DeclarableOpsTests9.exponentialDistributionInv_test1: RNG initialization failed ! " ) ;
2019-06-15 13:34:34 +02:00
2019-06-06 14:21:15 +02:00
functions : : random : : RandomFunction < double > : : template execTransform < randomOps : : ExponentialDistributionInv < double > > ( rng , x . getBuffer ( ) , x . getShapeInfo ( ) , extraParams ) ;
const double actualMean = x . meanNumber ( ) . e < double > ( 0 ) ;
const double actualStd = x . varianceNumber ( variance : : SummaryStatsStandardDeviation , true ) . e < double > ( 0 ) ;
2019-06-15 13:34:34 +02:00
2019-06-06 14:21:15 +02:00
ASSERT_NEAR ( mean , actualMean , 0.01 ) ;
2019-06-15 13:34:34 +02:00
ASSERT_NEAR ( std , actualStd , 0.01 ) ;
2019-06-06 14:21:15 +02:00
nativeOps . destroyRandom ( ( Nd4jPointer ) rng ) ;
delete [ ] buffer ;
2019-06-15 13:34:34 +02:00
2019-06-06 14:21:15 +02:00
}
//////////////////////////////////////////////////////////////////////////////
TEST_F ( DeclarableOpsTests9 , exponentialDistributionInv_test2 ) {
2019-06-15 13:34:34 +02:00
2019-06-06 14:21:15 +02:00
const int N = 50000 ;
const double lambda = 2. ;
2019-06-15 13:34:34 +02:00
const double mean = 1. / lambda ;
2019-06-06 14:21:15 +02:00
const double std = mean ;
double extraParams [ ] = { lambda } ;
auto x = NDArrayFactory : : create < double > ( ' c ' , { N } ) ;
auto y = NDArrayFactory : : create < double > ( ' c ' , { N } ) ;
y . linspace ( 0. , 1. / N ) ; // [0, 1)
Nd4jLong * buffer = new Nd4jLong [ N ] ;
NativeOps nativeOps ;
auto rng = ( nd4j : : random : : RandomBuffer * ) nativeOps . initRandom ( nullptr , 123 , N , ( Nd4jPointer ) buffer ) ;
if ( rng = = nullptr )
throw std : : runtime_error ( " DeclarableOpsTests9.exponentialDistributionInv_test2: RNG initialization failed ! " ) ;
2019-06-15 13:34:34 +02:00
2019-06-06 14:21:15 +02:00
functions : : random : : RandomFunction < double > : : template execTransform < randomOps : : ExponentialDistributionInv < double > > ( rng , y . getBuffer ( ) , y . getShapeInfo ( ) , x . getBuffer ( ) , x . getShapeInfo ( ) , extraParams ) ;
const double actualMean = x . meanNumber ( ) . e < double > ( 0 ) ;
const double actualStd = x . varianceNumber ( variance : : SummaryStatsStandardDeviation , true ) . e < double > ( 0 ) ;
ASSERT_NEAR ( mean , actualMean , 0.01 ) ;
2019-06-15 13:34:34 +02:00
ASSERT_NEAR ( std , actualStd , 0.01 ) ;
2019-06-06 14:21:15 +02:00
nativeOps . destroyRandom ( ( Nd4jPointer ) rng ) ;
delete [ ] buffer ;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F ( DeclarableOpsTests9 , exponentialDistribution_test1 ) {
2019-06-15 13:34:34 +02:00
2019-06-06 14:21:15 +02:00
const int N = 50000 ;
const double lambda = 2. ;
2019-06-15 13:34:34 +02:00
const double mean = 1. / lambda ;
2019-06-06 14:21:15 +02:00
const double std = mean ;
auto x = NDArrayFactory : : create < double > ( ' c ' , { N } ) ;
double extraParams [ ] = { lambda } ;
Nd4jLong * buffer = new Nd4jLong [ N ] ;
NativeOps nativeOps ;
auto rng = ( nd4j : : random : : RandomBuffer * ) nativeOps . initRandom ( nullptr , 123 , N , ( Nd4jPointer ) buffer ) ;
if ( rng = = nullptr )
throw std : : runtime_error ( " DeclarableOpsTests9.exponentialDistribution_test1: RNG initialization failed ! " ) ;
2019-06-15 13:34:34 +02:00
2019-06-06 14:21:15 +02:00
functions : : random : : RandomFunction < double > : : template execTransform < randomOps : : ExponentialDistribution < double > > ( rng , x . getBuffer ( ) , x . getShapeInfo ( ) , extraParams ) ;
const double actualMean = x . meanNumber ( ) . e < double > ( 0 ) ;
const double actualStd = x . varianceNumber ( variance : : SummaryStatsStandardDeviation , true ) . e < double > ( 0 ) ;
2019-06-15 13:34:34 +02:00
2019-06-06 14:21:15 +02:00
ASSERT_NEAR ( mean , actualMean , 0.01 ) ;
2019-06-15 13:34:34 +02:00
ASSERT_NEAR ( std , actualStd , 0.01 ) ;
2019-06-06 14:21:15 +02:00
nativeOps . destroyRandom ( ( Nd4jPointer ) rng ) ;
2019-06-15 13:34:34 +02:00
delete [ ] buffer ;
2019-06-06 14:21:15 +02:00
}
*/
//////////////////////////////////////////////////////////////////////////////
// Runs RandomFunction<double>::execTransform with randomOps::ExponentialDistribution<double>,
// passing y's buffer first and x's buffer second, then expects the mean and the standard
// deviation computed from x to both lie within 0.01 of 1/lambda (lambda = 2).
// NOTE(review): the transform call sits inside `#ifndef __CUDABLAS__`; when that macro is
// defined nothing in this test writes to x before the assertions run — confirm that is intended.
// NOTE(review): the bare date/time lines inside this body do not look like C++ (possibly
// version-control annotation residue) — confirm against the upstream file.
TEST_F ( DeclarableOpsTests9 , exponentialDistribution_test2 ) {
2019-06-15 13:34:34 +02:00
2019-06-06 14:21:15 +02:00
const int N = 50000 ;
const double lambda = 2. ;
2019-06-15 13:34:34 +02:00
// expected mean; the expected standard deviation below is set to the same value
const double mean = 1. / lambda ;
2019-06-06 14:21:15 +02:00
const double std = mean ;
double extraParams [ ] = { lambda } ;
auto x = NDArrayFactory : : create < double > ( ' c ' , { N } ) ;
auto y = NDArrayFactory : : create < double > ( ' c ' , { N } ) ;
y . linspace ( - N / 2. ) ; // [-25000, 25000)
// storage handed to initRandom below; released with delete[] at the end of the test
Nd4jLong * buffer = new Nd4jLong [ N ] ;
// Nd4jPointer extra[2];
# ifndef __CUDABLAS__
NativeOps nativeOps ;
nd4j : : random : : RandomBuffer * rng = ( nd4j : : random : : RandomBuffer * ) nativeOps . initRandom ( nullptr , 123 , N , ( Nd4jPointer ) buffer ) ;
// NOTE(review): buffer is not released on this throwing path
if ( rng = = nullptr )
throw std : : runtime_error ( " DeclarableOpsTests9.exponentialDistribution_test2: RNG initialization failed ! " ) ;
2019-06-15 13:34:34 +02:00
2019-06-06 14:21:15 +02:00
functions : : random : : RandomFunction < double > : : template execTransform < randomOps : : ExponentialDistribution < double > > ( rng , y . getBuffer ( ) , y . getShapeInfo ( ) , x . getBuffer ( ) , x . getShapeInfo ( ) , extraParams ) ;
// the RNG is destroyed before x's statistics are read
nativeOps . destroyRandom ( ( Nd4jPointer ) rng ) ;
# endif
const double actualMean = x . meanNumber ( ) . e < double > ( 0 ) ;
const double actualStd = x . varianceNumber ( variance : : SummaryStatsStandardDeviation , true ) . e < double > ( 0 ) ;
ASSERT_NEAR ( mean , actualMean , 0.01 ) ;
2019-06-15 13:34:34 +02:00
ASSERT_NEAR ( std , actualStd , 0.01 ) ;
2019-06-06 14:21:15 +02:00
delete [ ] buffer ;
}
// Applies scalar::Add with 1.0 to an 'f'-ordered 2x2 array, writing into a 'c'-ordered
// target, and compares the target with a 'c'-ordered expectation.
TEST_F(DeclarableOpsTests9, ScalarOpTest_MixedOrders_1) {
    auto source   = NDArrayFactory::create<double>('f', {2, 2}, {1.0, 3.0, 2.0, 4.0});
    auto expected = NDArrayFactory::create<double>('c', {2, 2}, {2.0, 3.0, 4.0, 5.0});
    auto target   = NDArrayFactory::create<double>('c', {2, 2}, {0.0, 0.0, 0.0, 0.0});

    source.applyScalar(scalar::Add, 1.0, &target);

    ASSERT_EQ(expected, target);
}
////////////////////////////////////////////////////////////////////////////////
// Concatenates three 'c'-ordered arrays of shapes {2,3,4}, {2,2,4} and {2,1,4} with integer
// argument 1 and checks the result against a hard-coded {2,6,4} array.
TEST_F(DeclarableOpsTests9, concat_test1) {

    auto first    = NDArrayFactory::create<double>('c', {2, 3, 4});
    auto second   = NDArrayFactory::create<double>('c', {2, 2, 4});
    auto third    = NDArrayFactory::create<double>('c', {2, 1, 4});
    auto expected = NDArrayFactory::create<double>('c', {2, 6, 4},
        {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f,
         1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f,
         1.f, 2.f, 3.f, 4.f,
         13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f,
         9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f,
         5.f, 6.f, 7.f, 8.});

    // every input is filled independently via linspace(1)
    first.linspace(1);
    second.linspace(1);
    third.linspace(1);

    nd4j::ops::concat concatOp;
    auto results = concatOp.execute({&first, &second, &third}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
// Concatenates arrays of shapes {1,3,1}, {1,2,1} and {1,1,1} with integer argument 1;
// the expected result has shape {1,6,1}.
TEST_F(DeclarableOpsTests9, concat_test2) {

    auto first    = NDArrayFactory::create<double>('c', {1, 3, 1});
    auto second   = NDArrayFactory::create<double>('c', {1, 2, 1});
    auto third    = NDArrayFactory::create<double>('c', {1, 1, 1});
    auto expected = NDArrayFactory::create<double>('c', {1, 6, 1}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f});

    first.linspace(1);
    second.linspace(1);
    third.linspace(1);

    nd4j::ops::concat concatOp;
    auto results = concatOp.execute({&first, &second, &third}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
// Concatenates rank-1 arrays of lengths 3, 2 and 1 with integer argument 0;
// the expected result is a length-6 vector.
TEST_F(DeclarableOpsTests9, concat_test3) {

    auto first    = NDArrayFactory::create<double>('c', {3});
    auto second   = NDArrayFactory::create<double>('c', {2});
    auto third    = NDArrayFactory::create<double>('c', {1});
    auto expected = NDArrayFactory::create<double>('c', {6}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f});

    first.linspace(1);
    second.linspace(1);
    third.linspace(1);

    nd4j::ops::concat concatOp;
    auto results = concatOp.execute({&first, &second, &third}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
// Concatenates three {1,1,1} arrays holding 1, 2 and 3 with integer argument 1;
// the expected result has shape {1,3,1}.
TEST_F(DeclarableOpsTests9, concat_test4) {

    auto first    = NDArrayFactory::create<double>('c', {1, 1, 1}, {1.f});
    auto second   = NDArrayFactory::create<double>('c', {1, 1, 1}, {2.f});
    auto third    = NDArrayFactory::create<double>('c', {1, 1, 1}, {3.f});
    auto expected = NDArrayFactory::create<double>('c', {1, 3, 1}, {1.f, 2.f, 3.f});

    nd4j::ops::concat concatOp;
    auto results = concatOp.execute({&first, &second, &third}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
// Concatenates two arrays built from a single value with a {1}-shaped array in between
// (integer argument 0); the expected result is the length-3 vector {1, 2, 3}.
TEST_F(DeclarableOpsTests9, concat_test5) {

    auto first    = NDArrayFactory::create<double>(1.f);
    auto second   = NDArrayFactory::create<double>('c', {1}, {2.f});
    auto third    = NDArrayFactory::create<double>(3.f);
    auto expected = NDArrayFactory::create<double>('c', {3}, {1.f, 2.f, 3.f});

    nd4j::ops::concat concatOp;
    auto results = concatOp.execute({&first, &second, &third}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
// Concatenates a single-value array, a length-2 vector and another single-value array
// (integer argument 0); the expected result is the length-4 vector {1, 2, 20, 3}.
TEST_F(DeclarableOpsTests9, concat_test6) {

    auto first    = NDArrayFactory::create<double>(1.f);
    auto second   = NDArrayFactory::create<double>('c', {2}, {2.f, 20.f});
    auto third    = NDArrayFactory::create<double>(3.f);
    auto expected = NDArrayFactory::create<double>('c', {4}, {1.f, 2.f, 20.f, 3.f});

    nd4j::ops::concat concatOp;
    auto results = concatOp.execute({&first, &second, &third}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
// Concatenates three arrays each built from a single value (1, 2, 3) with integer
// argument 0; the expected result is a length-3 vector.
TEST_F(DeclarableOpsTests9, concat_test7) {

    auto first    = NDArrayFactory::create<double>(1.f);
    auto second   = NDArrayFactory::create<double>(2.f);
    auto third    = NDArrayFactory::create<double>(3.f);
    auto expected = NDArrayFactory::create<double>('c', {3}, {1.f, 2.f, 3.f});

    nd4j::ops::concat concatOp;
    auto results = concatOp.execute({&first, &second, &third}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
// Concat with a single input built from one value; the expected output has shape {1}.
TEST_F(DeclarableOpsTests9, concat_test8) {

    auto single   = NDArrayFactory::create<double>(1.f);
    auto expected = NDArrayFactory::create<double>('c', {1}, {1.f});

    nd4j::ops::concat concatOp;
    auto results = concatOp.execute({&single}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
// Concat with a single {1}-shaped input; the expected output equals the input.
TEST_F(DeclarableOpsTests9, concat_test9) {

    auto single   = NDArrayFactory::create<double>('c', {1}, {1.f});
    auto expected = NDArrayFactory::create<double>('c', {1}, {1.f});

    nd4j::ops::concat concatOp;
    auto results = concatOp.execute({&single}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
// Mixed-order concat: inputs are 'c' {2,3,4}, 'f' {2,2,4} and 'c' {2,1,4}, integer
// argument 1; the result is compared with a 'c'-ordered {2,6,4} expectation.
TEST_F(DeclarableOpsTests9, concat_test10) {

    auto first    = NDArrayFactory::create<double>('c', {2, 3, 4});
    auto second   = NDArrayFactory::create<double>('f', {2, 2, 4});
    auto third    = NDArrayFactory::create<double>('c', {2, 1, 4});
    auto expected = NDArrayFactory::create<double>('c', {2, 6, 4},
        {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f,
         1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f,
         1.f, 2.f, 3.f, 4.f,
         13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f,
         9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f,
         5.f, 6.f, 7.f, 8.f});

    first.linspace(1);
    second.linspace(1);
    third.linspace(1);

    nd4j::ops::concat concatOp;
    auto results = concatOp.execute({&first, &second, &third}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
// Mixed-order concat: inputs are 'c' {2,3,4}, 'f' {2,2,4} and 'f' {2,1,4}, integer
// argument 1; the result is compared with a 'c'-ordered {2,6,4} expectation.
TEST_F(DeclarableOpsTests9, concat_test11) {

    auto first    = NDArrayFactory::create<double>('c', {2, 3, 4});
    auto second   = NDArrayFactory::create<double>('f', {2, 2, 4});
    auto third    = NDArrayFactory::create<double>('f', {2, 1, 4});
    auto expected = NDArrayFactory::create<double>('c', {2, 6, 4},
        {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f,
         1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f,
         1.f, 2.f, 3.f, 4.f,
         13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f,
         9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f,
         5.f, 6.f, 7.f, 8.f});

    first.linspace(1);
    second.linspace(1);
    third.linspace(1);

    nd4j::ops::concat concatOp;
    auto results = concatOp.execute({&first, &second, &third}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
// Mixed-order concat with the same inputs and expectation as concat_test11:
// 'c' {2,3,4}, 'f' {2,2,4} and 'f' {2,1,4}, integer argument 1.
TEST_F(DeclarableOpsTests9, concat_test12) {

    auto first    = NDArrayFactory::create<double>('c', {2, 3, 4});
    auto second   = NDArrayFactory::create<double>('f', {2, 2, 4});
    auto third    = NDArrayFactory::create<double>('f', {2, 1, 4});
    auto expected = NDArrayFactory::create<double>('c', {2, 6, 4},
        {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f,
         1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f,
         1.f, 2.f, 3.f, 4.f,
         13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f,
         9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f,
         5.f, 6.f, 7.f, 8.f});

    first.linspace(1);
    second.linspace(1);
    third.linspace(1);

    nd4j::ops::concat concatOp;
    auto results = concatOp.execute({&first, &second, &third}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
// All-'f'-order concat: inputs of shapes {2,3,4}, {2,2,4} and {2,1,4}, integer argument 1;
// the result is compared with an 'f'-ordered {2,6,4} expectation.
TEST_F(DeclarableOpsTests9, concat_test13) {

    auto first    = NDArrayFactory::create<double>('f', {2, 3, 4});
    auto second   = NDArrayFactory::create<double>('f', {2, 2, 4});
    auto third    = NDArrayFactory::create<double>('f', {2, 1, 4});
    auto expected = NDArrayFactory::create<double>('f', {2, 6, 4},
        {1.f, 13.f, 5.f, 17.f, 9.f, 21.f, 1.f, 9.f, 5.f, 13.f, 1.f, 5.f,
         2.f, 14.f, 6.f, 18.f, 10.f, 22.f, 2.f, 10.f, 6.f, 14.f, 2.f, 6.f,
         3.f, 15.f, 7.f, 19.f, 11.f, 23.f, 3.f, 11.f, 7.f, 15.f, 3.f, 7.f,
         4.f, 16.f, 8.f, 20.f, 12.f, 24.f, 4.f, 12.f, 8.f, 16.f, 4.f, 8.f});

    first.linspace(1);
    second.linspace(1);
    third.linspace(1);

    nd4j::ops::concat concatOp;
    auto results = concatOp.execute({&first, &second, &third}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
// Concatenates two {1,40,60} arrays assigned 1. and 2. with integer argument 0, then checks
// that the output yields two sub-arrays (as counted by ShapeUtils::getNumOfSubArrs with {0})
// and that sub-array k has mean k+1.
TEST_F(DeclarableOpsTests9, concat_test14) {

    NDArray ones('c', {1, 40, 60}, nd4j::DataType::DOUBLE);
    NDArray twos('c', {1, 40, 60}, nd4j::DataType::DOUBLE);
    ones = 1.;
    twos = 2.;

    nd4j::ops::concat concatOp;
    auto results = concatOp.execute({&ones, &twos}, {}, {0}, {});
    ASSERT_EQ(Status::OK(), results->status());

    auto joined = results->at(0);

    Nd4jLong subArrCount = ShapeUtils::getNumOfSubArrs(joined->getShapeInfo(), {0});
    ASSERT_TRUE(2 == subArrCount);

    for (int idx = 0; idx < subArrCount; ++idx) {
        NDArray slice = (*joined)(idx, {0});
        auto sliceMean = slice.meanNumber().e<double>(0);
        ASSERT_NEAR((idx + 1) * 1., sliceMean, 1e-5);
    }

    delete results;
}
// Concatenates the length-2 vector {1, 0} with an array built from the single value 3.0f
// (integer argument 0); the expected result is the length-3 vector {1, 0, 3}.
TEST_F(DeclarableOpsTests9, concat_test15) {

    auto vec      = NDArrayFactory::create<double>('c', {2}, {1, 0});
    auto single   = NDArrayFactory::create<double>(3.0f);
    auto expected = NDArrayFactory::create<double>('c', {3}, {1, 0, 3});

    nd4j::ops::concat concatOp;
    auto results = concatOp.execute({&vec, &single}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto joined = results->at(0);
    ASSERT_TRUE(expected.isSameShape(joined));
    ASSERT_TRUE(expected.equalsTo(joined));

    delete results;
}
2019-06-15 13:34:34 +02:00
//////////////////////////////////////////////////////////////////////
// Concatenates two arrays of shape {0,2,3} with integer argument 0. Only the status and
// the output shape ({0,2,3}) are checked; element values are not compared.
TEST_F(DeclarableOpsTests9, concat_test16) {

    auto lhs      = NDArrayFactory::create<double>('c', {0, 2, 3});
    auto rhs      = NDArrayFactory::create<double>('c', {0, 2, 3});
    auto expected = NDArrayFactory::create<double>('c', {0, 2, 3});

    nd4j::ops::concat concatOp;
    auto results = concatOp.execute({&lhs, &rhs}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto joined = results->at(0);
    ASSERT_TRUE(expected.isSameShape(joined));

    delete results;
}
2019-07-12 10:51:51 +02:00
//////////////////////////////////////////////////////////////////////
// tile_bp with a {2,3} input, a {4,9} output gradient filled by linspace(0.01, 0.01) and
// integer arguments {2, 3}; the resulting input gradient must match gradIExp.
TEST_F(DeclarableOpsTests9, tile_bp_test1) {

    auto input    = NDArrayFactory::create<double>('c', {2, 3}, {1., 2., 3., 4., 5., 6.});
    auto gradO    = NDArrayFactory::create<double>('c', {4, 9});
    auto gradIExp = NDArrayFactory::create<double>('c', {2, 3}, {0.78, 0.84, 0.9, 1.32, 1.38, 1.44});

    gradO.linspace(0.01, 0.01);

    nd4j::ops::tile_bp op;
    auto results = op.execute({&input, &gradO}, {}, {2, 3});

    // Assert the status before touching the outputs, so a failed op is reported as an
    // assertion failure first (same order as the concat tests in this file).
    ASSERT_EQ(Status::OK(), results->status());
    auto gradI = results->at(0);

    ASSERT_TRUE(gradIExp.isSameShape(gradI));
    ASSERT_TRUE(gradIExp.equalsTo(gradI));

    delete results;
}
//////////////////////////////////////////////////////////////////////
// tile_bp with a {2,3} input, a {2,9} output gradient filled by linspace(0.01, 0.01) and
// integer arguments {1, 3}; the resulting input gradient must match gradIExp.
TEST_F(DeclarableOpsTests9, tile_bp_test2) {

    auto input    = NDArrayFactory::create<double>('c', {2, 3}, {1., 2., 3., 4., 5., 6.});
    auto gradO    = NDArrayFactory::create<double>('c', {2, 9});
    auto gradIExp = NDArrayFactory::create<double>('c', {2, 3}, {0.12, 0.15, 0.18, 0.39, 0.42, 0.45});

    gradO.linspace(0.01, 0.01);

    nd4j::ops::tile_bp op;
    auto results = op.execute({&input, &gradO}, {}, {1, 3});

    // Assert the status before touching the outputs, so a failed op is reported as an
    // assertion failure first (same order as the concat tests in this file).
    ASSERT_EQ(Status::OK(), results->status());
    auto gradI = results->at(0);

    ASSERT_TRUE(gradIExp.isSameShape(gradI));
    ASSERT_TRUE(gradIExp.equalsTo(gradI));

    delete results;
}
2019-06-06 14:21:15 +02:00
//////////////////////////////////////////////////////////////////////
// tile_bp with integer arguments {1, 1}: input and output gradient share shape {2,3}, and
// the expected input gradient holds the same values linspace(0.01, 0.01) writes to gradO.
TEST_F(DeclarableOpsTests9, tile_bp_test3) {

    auto input    = NDArrayFactory::create<double>('c', {2, 3}, {1., 2., 3., 4., 5., 6.});
    auto gradO    = NDArrayFactory::create<double>('c', {2, 3});
    auto gradIExp = NDArrayFactory::create<double>('c', {2, 3}, {0.01, 0.02, 0.03, 0.04, 0.05, 0.06});

    gradO.linspace(0.01, 0.01);

    nd4j::ops::tile_bp op;
    auto results = op.execute({&input, &gradO}, {}, {1, 1});

    // Assert the status before touching the outputs, so a failed op is reported as an
    // assertion failure first (same order as the concat tests in this file).
    ASSERT_EQ(Status::OK(), results->status());
    auto gradI = results->at(0);

    ASSERT_TRUE(gradIExp.isSameShape(gradI));
    ASSERT_TRUE(gradIExp.equalsTo(gradI));

    delete results;
}
//////////////////////////////////////////////////////////////////////
// tile_bp on rank-1 data: a length-6 input, a length-12 output gradient filled by
// linspace(0.01, 0.01) and integer argument {2}.
TEST_F(DeclarableOpsTests9, tile_bp_test4) {

    auto input    = NDArrayFactory::create<double>('c', {6}, {1., 2., 3., 4., 5., 6.});
    auto gradO    = NDArrayFactory::create<double>('c', {12});
    auto gradIExp = NDArrayFactory::create<double>('c', {6}, {0.08, 0.1, 0.12, 0.14, 0.16, 0.18});

    gradO.linspace(0.01, 0.01);

    nd4j::ops::tile_bp op;
    auto results = op.execute({&input, &gradO}, {}, {2});

    // Assert the status before touching the outputs, so a failed op is reported as an
    // assertion failure first (same order as the concat tests in this file).
    ASSERT_EQ(Status::OK(), results->status());
    auto gradI = results->at(0);

    ASSERT_TRUE(gradIExp.isSameShape(gradI));
    ASSERT_TRUE(gradIExp.equalsTo(gradI));

    delete results;
}
//////////////////////////////////////////////////////////////////////
// tile_bp on single-element arrays: {1}-shaped input and output gradient, integer
// argument {1}; the expected input gradient is {0.01}.
TEST_F(DeclarableOpsTests9, tile_bp_test5) {

    auto input    = NDArrayFactory::create<double>('c', {1}, {1.});
    auto gradO    = NDArrayFactory::create<double>('c', {1});
    auto gradIExp = NDArrayFactory::create<double>('c', {1}, {0.01});

    gradO.linspace(0.01, 0.01);

    nd4j::ops::tile_bp op;
    auto results = op.execute({&input, &gradO}, {}, {1});

    // Assert the status before touching the outputs, so a failed op is reported as an
    // assertion failure first (same order as the concat tests in this file).
    ASSERT_EQ(Status::OK(), results->status());
    auto gradI = results->at(0);

    ASSERT_TRUE(gradIExp.isSameShape(gradI));
    ASSERT_TRUE(gradIExp.equalsTo(gradI));

    delete results;
}
//////////////////////////////////////////////////////////////////////
// tile_bp on rank-3 data: a {2,1,3} input, a {2,3,6} output gradient filled by
// linspace(0.01, 0.01) and integer arguments {1, 3, 2}.
TEST_F(DeclarableOpsTests9, tile_bp_test6) {

    auto input    = NDArrayFactory::create<double>('c', {2, 1, 3}, {1., 2., 3., 4., 5., 6.});
    auto gradO    = NDArrayFactory::create<double>('c', {2, 3, 6});
    auto gradIExp = NDArrayFactory::create<double>('c', {2, 1, 3}, {0.51, 0.57, 0.63, 1.59, 1.65, 1.71});

    gradO.linspace(0.01, 0.01);

    nd4j::ops::tile_bp op;
    auto results = op.execute({&input, &gradO}, {}, {1, 3, 2});

    // Assert the status before touching the outputs, so a failed op is reported as an
    // assertion failure first (same order as the concat tests in this file).
    ASSERT_EQ(Status::OK(), results->status());
    auto gradI = results->at(0);

    ASSERT_TRUE(gradIExp.isSameShape(gradI));
    ASSERT_TRUE(gradIExp.equalsTo(gradI));

    delete results;
}
//////////////////////////////////////////////////////////////////////
// Same data and expectation as tile_bp_test6, but the values {1, 3, 2} are supplied as an
// int array placed between the input and the output gradient instead of as integer arguments.
TEST_F(DeclarableOpsTests9, tile_bp_test7) {

    auto input    = NDArrayFactory::create<double>('c', {2, 1, 3}, {1., 2., 3., 4., 5., 6.});
    auto reps     = NDArrayFactory::create<int>('c', {1, 3}, {1, 3, 2});
    auto gradO    = NDArrayFactory::create<double>('c', {2, 3, 6});
    auto gradIExp = NDArrayFactory::create<double>('c', {2, 1, 3}, {0.51, 0.57, 0.63, 1.59, 1.65, 1.71});

    gradO.linspace(0.01, 0.01);

    nd4j::ops::tile_bp op;
    auto results = op.execute({&input, &reps, &gradO}, {}, {});

    // Assert the status before touching the outputs, so a failed op is reported as an
    // assertion failure first (same order as the concat tests in this file).
    ASSERT_EQ(Status::OK(), results->status());
    auto gradI = results->at(0);

    ASSERT_TRUE(gradIExp.isSameShape(gradI));
    ASSERT_TRUE(gradIExp.equalsTo(gradI));

    delete results;
}
//////////////////////////////////////////////////////////////////////
// tile with a {1,6} input and an int array {2, 1} passed as second input; the expected
// output has shape {2,6} and holds the input row twice.
TEST_F(DeclarableOpsTests9, tile_test1) {

    auto input  = NDArrayFactory::create<double>('c', {1, 6}, {1., 2., 3., 4., 5., 6.});
    auto reps   = NDArrayFactory::create<int>('c', {1, 2}, {2, 1});
    auto expOut = NDArrayFactory::create<double>('c', {2, 6}, {1., 2., 3., 4., 5., 6., 1., 2., 3., 4., 5., 6.});

    // The unconditional debug print of expOut was dropped: it added noise to every run
    // and took no part in any assertion.

    nd4j::ops::tile op;
    auto results = op.execute({&input, &reps}, {}, {});

    // Assert the status before touching the outputs, so a failed op is reported as an
    // assertion failure first (same order as the concat tests in this file).
    ASSERT_EQ(Status::OK(), results->status());
    auto out = results->at(0);

    ASSERT_TRUE(expOut.isSameShape(out));
    ASSERT_TRUE(expOut.equalsTo(out));

    delete results;
}
//////////////////////////////////////////////////////////////////////
// matmul of a 'c'-ordered {3,4} array (linspace(1.)) with a 'c'-ordered {4,3} array
// (linspace(0.5, 0.5)); the product is compared with an 'f'-ordered {3,3} expectation.
TEST_F(DeclarableOpsTests9, matmul_test1) {

    auto x   = NDArrayFactory::create<double>('c', {3, 4});
    auto y   = NDArrayFactory::create<double>('c', {4, 3});
    auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.});

    x.linspace(1.);
    y.linspace(0.5, 0.5);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {});

    // Assert the status before touching the outputs, so a failed op is reported as an
    // assertion failure first (same order as the concat tests in this file).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
//////////////////////////////////////////////////////////////////////
// matmul of a 'c'-ordered {3,4} array (linspace(1.)) with an 'f'-ordered {4,3} array
// (linspace(0.5, 0.5)); the product is compared with an 'f'-ordered {3,3} expectation.
TEST_F(DeclarableOpsTests9, matmul_test2) {

    auto x   = NDArrayFactory::create<double>('c', {3, 4});
    auto y   = NDArrayFactory::create<double>('f', {4, 3});
    auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.});

    x.linspace(1.);
    y.linspace(0.5, 0.5);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {});

    // Assert the status before touching the outputs, so a failed op is reported as an
    // assertion failure first (same order as the concat tests in this file).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
//////////////////////////////////////////////////////////////////////
// matmul of an 'f'-ordered {3,4} array (linspace(1.)) with a 'c'-ordered {4,3} array
// (linspace(0.5, 0.5)); the product is compared with an 'f'-ordered {3,3} expectation.
TEST_F(DeclarableOpsTests9, matmul_test3) {

    auto x   = NDArrayFactory::create<double>('f', {3, 4});
    auto y   = NDArrayFactory::create<double>('c', {4, 3});
    auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.});

    x.linspace(1.);
    y.linspace(0.5, 0.5);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {});

    // Assert the status before touching the outputs, so a failed op is reported as an
    // assertion failure first (same order as the concat tests in this file).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test4) {
    // Same product as matmul_test1, with both operands created in 'f' order.
    auto x = NDArrayFactory::create<double>('f', {3, 4});
    auto y = NDArrayFactory::create<double>('f', {4, 3});
    auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.});

    // Fill the operands with linear sequences.
    x.linspace(1.);
    y.linspace(0.5, 0.5);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {});

    // Check the op status before touching the result set (same order as matmul_test11/12).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test5) {
    // Both operands are 4x3; integer args {1}. Judging by the 3x3 expected
    // shape, the first flag makes the op use x transposed (x^T * y).
    auto x = NDArrayFactory::create<double>('c', {4, 3});
    auto y = NDArrayFactory::create<double>('c', {4, 3});
    auto exp = NDArrayFactory::create<double>('f', {3, 3}, {83., 94., 105., 94., 107., 120., 105., 120., 135.});

    // Fill the operands with linear sequences.
    x.linspace(1.);
    y.linspace(0.5, 0.5);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {1});

    // Check the op status before touching the result set (same order as matmul_test11/12).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test6) {
    // x is 4x3, y is 3x4; integer args {1, 1}. Judging by the 3x3 expected
    // shape, both operands are used transposed (x^T * y^T).
    auto x = NDArrayFactory::create<double>('c', {4, 3});
    auto y = NDArrayFactory::create<double>('f', {3, 4});
    auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.});

    // Fill the operands with linear sequences.
    x.linspace(1.);
    y.linspace(0.5, 0.5);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {1, 1});

    // Check the op status before touching the result set (same order as matmul_test11/12).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test7) {
    // Rank-3 case: both operands are 5x3x4; integer args {0, 1}. Judging by
    // the 5x3x3 expected shape, y is used transposed in its last two dims.
    auto x = NDArrayFactory::create<double>('c', {5, 3, 4});
    auto y = NDArrayFactory::create<double>('f', {5, 3, 4});
    auto exp = NDArrayFactory::create<double>('f', {5, 3, 3}, {3., 84.6, 281.4, 593.4, 1020.6, 7., 107.8, 323.8, 655., 1101.4, 11., 131., 366.2, 716.6, 1182.2,
                                                               7., 107.8, 323.8, 655., 1101.4, 17.4, 137.4, 372.6, 723., 1188.6, 27.8, 167., 421.4, 791., 1275.8,
                                                               11., 131., 366.2, 716.6, 1182.2, 27.8, 167., 421.4, 791., 1275.8, 44.6, 203., 476.6, 865.4, 1369.4});

    // Fill the operands with linear sequences.
    x.linspace(1.);
    y.linspace(0.1, 0.1);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {0, 1});

    // Check the op status before touching the result set (same order as matmul_test11/12).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test8) {
    // Rank-4 case: both operands are 2x5x3x4; integer args {0, 1}. Judging by
    // the 2x5x3x3 expected shape, y is used transposed in its last two dims.
    auto x = NDArrayFactory::create<double>('c', {2, 5, 3, 4});
    auto y = NDArrayFactory::create<double>('f', {2, 5, 3, 4});
    auto exp = NDArrayFactory::create<double>('f', {2, 5, 3, 3}, {3., 1563., 84.6, 2220.6, 281.4, 2993.4, 593.4, 3881.4, 1020.6, 4884.6, 7., 1663., 107.8, 2339.8, 323.8, 3131.8, 655., 4039., 1101.4, 5061.4,
                                                                  11., 1763., 131., 2459., 366.2, 3270.2, 716.6, 4196.6, 1182.2, 5238.2, 7., 1663., 107.8, 2339.8, 323.8, 3131.8, 655., 4039., 1101.4, 5061.4,
                                                                  17.4, 1769.4, 137.4, 2465.4, 372.6, 3276.6, 723., 4203., 1188.6, 5244.6, 27.8, 1875.8, 167., 2591., 421.4, 3421.4, 791., 4367., 1275.8, 5427.8,
                                                                  11., 1763., 131., 2459., 366.2, 3270.2, 716.6, 4196.6, 1182.2, 5238.2, 27.8, 1875.8, 167., 2591., 421.4, 3421.4, 791., 4367., 1275.8, 5427.8,
                                                                  44.6, 1988.6, 203., 2723., 476.6, 3572.6, 865.4, 4537.4, 1369.4, 5617.4});

    // Fill the operands with linear sequences.
    x.linspace(1.);
    y.linspace(0.1, 0.1);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {0, 1});

    // Check the op status before touching the result set (same order as matmul_test11/12).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
2019-06-15 13:34:34 +02:00
2019-06-06 14:21:15 +02:00
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test9) {
    // Rank-4 case: x is 2x5x4x3, y is 2x5x3x4; integer args {1, 1}. Judging by
    // the 2x5x3x3 expected shape, both operands are used transposed.
    auto x = NDArrayFactory::create<double>('c', {2, 5, 4, 3});
    auto y = NDArrayFactory::create<double>('f', {2, 5, 3, 4});
    auto exp = NDArrayFactory::create<double>('f', {2, 5, 3, 3}, {7., 1639., 103., 2311., 314.2, 3098.2, 640.6, 4000.6, 1082.2, 5018.2, 8., 1664., 108.8, 2340.8, 324.8, 3132.8, 656., 4040., 1102.4, 5062.4,
                                                                  9., 1689., 114.6, 2370.6, 335.4, 3167.4, 671.4, 4079.4, 1122.6, 5106.6, 15.8, 1743.8, 131., 2435., 361.4, 3241.4, 707., 4163., 1167.8, 5199.8,
                                                                  18.4, 1770.4, 138.4, 2466.4, 373.6, 3277.6, 724., 4204., 1189.6, 5245.6, 21., 1797., 145.8, 2497.8, 385.8, 3313.8, 741., 4245., 1211.4, 5291.4,
                                                                  24.6, 1848.6, 159., 2559., 408.6, 3384.6, 773.4, 4325.4, 1253.4, 5381.4, 28.8, 1876.8, 168., 2592., 422.4, 3422.4, 792., 4368., 1276.8, 5428.8,
                                                                  33., 1905., 177., 2625., 436.2, 3460.2, 810.6, 4410.6, 1300.2, 5476.2});

    // Fill the operands with linear sequences.
    x.linspace(1.);
    y.linspace(0.1, 0.1);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {1, 1});

    // Check the op status before touching the result set (same order as matmul_test11/12).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, TestDropout_BP_1) {
    // The input and the incoming gradient hold the same values 1..8 in a 2x2x2 layout.
    NDArray x('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
    NDArray errs('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
    NDArray shape('c', {2}, {2, 2});

    nd4j::ops::dropout_bp backpropOp;
    auto outcome = backpropOp.execute({&x, &errs, &shape}, {0.2f}, {113});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    // The back-propagated gradient must differ from the gradient that was fed in.
    NDArray* propagated = outcome->at(0);
    ASSERT_FALSE(propagated->equalsTo(errs));

    delete outcome;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, TestDropout_1) {
    // Runs dropout twice with identical input, T-arg 0.2f and I-arg 113, and checks
    // that each run zeroes roughly 80 of the 100 elements and that both runs
    // produce the same output.
    NDArray x('c', {10, 10}, nd4j::DataType::FLOAT32);
    x.linspace(1);

    nd4j::ops::dropout op;

    // First run.
    auto ress = op.execute({&x}, {0.2f}, {113});
    ASSERT_EQ(ND4J_STATUS_OK, ress->status());
    NDArray* res = ress->at(0);
    auto countZero = res->reduceNumber(reduce::CountZero);
    ASSERT_NEAR(countZero.e<Nd4jLong>(0), 80, 5);

    // Second run with the same arguments.
    auto ress2 = op.execute({&x}, {0.2f}, {113});
    ASSERT_EQ(ND4J_STATUS_OK, ress2->status());
    NDArray* res2 = ress2->at(0);
    // Count zeros of the SECOND result; previously the first result was counted
    // again, so this check never looked at res2.
    countZero = res2->reduceNumber(reduce::CountZero);
    ASSERT_NEAR(countZero.e<Nd4jLong>(0), 80, 5);

    // Both runs must yield identical outputs.
    ASSERT_TRUE(res->equalsTo(res2));

    delete ress;
    delete ress2;
}
TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
    // Runs forward dropout and dropout_bp with the same T-arg (0.5f) and I-arg (119)
    // and checks that every element zeroed by the forward pass is also zero in
    // the back-propagated output.
    NDArray x0('c', {10, 10}, nd4j::DataType::FLOAT32);
    NDArray x1('c', {10, 10}, nd4j::DataType::FLOAT32);
    x0.linspace(1);
    x1.linspace(1);

    // Forward pass.
    nd4j::ops::dropout op;
    auto ress = op.execute({&x1}, {0.5f}, {119});
    ASSERT_EQ(ND4J_STATUS_OK, ress->status());

    // Backward passes; the second one is only checked for a successful status.
    nd4j::ops::dropout_bp op2;
    auto ressX = op2.execute({&x1, &x1}, {0.5f}, {119});
    ASSERT_EQ(ND4J_STATUS_OK, ressX->status());
    auto ressY = op2.execute({&x1, &x0}, {0.5f}, {119});
    ASSERT_EQ(ND4J_STATUS_OK, ressY->status());

    // Fetch the outputs once instead of on every loop iteration.
    NDArray* forward = ress->at(0);
    NDArray* backward = ressX->at(0);

    bool ret = true;
    for (int i = 0; i < forward->lengthOf(); i++) {
        // Wherever the forward output is zero, the backward output must match it.
        if (forward->e<float>(i) == 0.f && backward->e<float>(i) != forward->e<float>(i)) {
            ret = false;
            break;
        }
    }
    ASSERT_TRUE(ret);

    delete ressX;
    delete ressY;
    delete ress;
}
TEST_F(DeclarableOpsTests9, Test_Dropout_BP_2) {
    // One forward dropout run and two dropout_bp runs, all with T-arg 0.5f and
    // I-arg 119. Each output should have about half of its 100 elements zeroed,
    // and the two backward outputs must be equal.
    NDArray x('c', {10, 10}, nd4j::DataType::FLOAT32);
    x.linspace(1);

    nd4j::ops::dropout forwardOp;
    auto forward = forwardOp.execute({&x}, {0.5f}, {119});
    ASSERT_EQ(ND4J_STATUS_OK, forward->status());

    nd4j::ops::dropout_bp backwardOp;
    auto firstBp = backwardOp.execute({&x, &x}, {0.5f}, {119});
    ASSERT_EQ(ND4J_STATUS_OK, firstBp->status());
    auto secondBp = backwardOp.execute({&x, &x}, {0.5f}, {119});
    ASSERT_EQ(ND4J_STATUS_OK, secondBp->status());

    // Zero count of the forward output.
    auto zeros = forward->at(0)->reduceNumber(reduce::CountZero);
    ASSERT_NEAR(zeros.e<float>(0), 50.f, 10.f);

    // Zero count of the first backward output.
    zeros = firstBp->at(0)->reduceNumber(reduce::CountZero);
    ASSERT_NEAR(zeros.e<float>(0), 50.f, 10.f);

    // Zero count of the second backward output.
    zeros = secondBp->at(0)->reduceNumber(reduce::CountZero);
    ASSERT_NEAR(zeros.e<float>(0), 50.f, 10.f);

    // Identical arguments must give identical backward outputs.
    ASSERT_TRUE(firstBp->at(0)->equalsTo(secondBp->at(0)));

    delete firstBp;
    delete secondBp;
    delete forward;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, Test_AlphaDropout_BP_1) {
    // Executes alpha_dropout_bp twice with identical inputs and arguments and
    // expects both outputs to be equal.
    NDArray x('c', {10, 10}, nd4j::DataType::FLOAT32);
    NDArray eps('c', {10, 10}, nd4j::DataType::FLOAT32);
    x.linspace(1);
    eps.linspace(1);

    nd4j::ops::alpha_dropout_bp op;

    auto firstRun = op.execute({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119});
    ASSERT_EQ(ND4J_STATUS_OK, firstRun->status());
    NDArray* firstOut = firstRun->at(0);

    auto secondRun = op.execute({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119});
    ASSERT_EQ(ND4J_STATUS_OK, secondRun->status());
    NDArray* secondOut = secondRun->at(0);

    ASSERT_TRUE(secondOut->equalsTo(firstOut));

    delete firstRun;
    delete secondRun;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test10) {
    // Rank-3 case with a leading unit dim: x is 1x4x3, y is 1x3x4; integer
    // args {1, 1}. Judging by the 1x3x3 expected shape, both are used transposed.
    auto x = NDArrayFactory::create<double>('c', {1, 4, 3});
    auto y = NDArrayFactory::create<double>('f', {1, 3, 4});
    auto exp = NDArrayFactory::create<double>('f', {1, 3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.});

    // Fill the operands with linear sequences.
    x.linspace(1.);
    y.linspace(0.5, 0.5);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {1, 1});

    // Check the op status before touching the result set (same order as matmul_test11/12).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test11) {
    // x is 4x1, y is 1x4, integer args {1, 1}; the expected output is a 1x1 array holding 15.
    auto x = NDArrayFactory::create<double>('c', {4, 1});
    auto y = NDArrayFactory::create<double>('f', {1, 4});
    auto expected = NDArrayFactory::create<double>('f', {1, 1}, {15});

    // Fill the operands with linear sequences.
    x.linspace(1.);
    y.linspace(0.5, 0.5);

    nd4j::ops::matmul matmulOp;
    auto outcome = matmulOp.execute({&x, &y}, {}, {1, 1});
    ASSERT_EQ(Status::OK(), outcome->status());

    auto product = outcome->at(0);
    ASSERT_TRUE(expected.isSameShape(product));
    ASSERT_TRUE(expected.equalsTo(product));

    delete outcome;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test12) {
    // Rank-3 variant of the 4x1 / 1x4 case: x is 1x4x1, y is 1x1x4, integer
    // args {1, 1}; the expected output is a 1x1x1 array holding 15.
    auto x = NDArrayFactory::create<double>('c', {1, 4, 1});
    auto y = NDArrayFactory::create<double>('f', {1, 1, 4});
    auto exp = NDArrayFactory::create<double>('f', {1, 1, 1}, {15});

    // Fill the operands with linear sequences.
    x.linspace(1.);
    y.linspace(0.5, 0.5);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {1, 1});
    ASSERT_EQ(Status::OK(), results->status());

    auto z = results->at(0);
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test13) {
    // x is 2x3, y is 3x5; integer args {0, 0, 1}. Judging by the 5x2 expected
    // shape, the third flag requests a transposed result.
    auto x = NDArrayFactory::create<double>('c', {2, 3});
    auto y = NDArrayFactory::create<double>('c', {3, 5});
    auto exp = NDArrayFactory::create<double>('f', {5, 2}, {23., 26., 29., 32., 35., 50., 57.5, 65., 72.5, 80.});

    // Fill the operands with linear sequences.
    x.linspace(1.);
    y.linspace(0.5, 0.5);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {0, 0, 1});

    // Check the op status before touching the result set (same order as matmul_test11/12).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test14) {
    // x is 3x2, y is 3x5; integer args {1, 0, 1}. Judging by the 5x2 expected
    // shape, x is used transposed and the result is transposed as well.
    auto x = NDArrayFactory::create<double>('c', {3, 2});
    auto y = NDArrayFactory::create<double>('c', {3, 5});
    auto exp = NDArrayFactory::create<double>('f', {5, 2}, {37., 41.5, 46., 50.5, 55., 46., 52., 58., 64., 70.});

    // Fill the operands with linear sequences.
    x.linspace(1.);
    y.linspace(0.5, 0.5);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {1, 0, 1});

    // Check the op status before touching the result set (same order as matmul_test11/12).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test15) {
    // x is 3x2, y is 3x5; integer args {1, 0, 1}. Judging by the 5x2 expected
    // shape, x is used transposed and the result is transposed as well.
    // NOTE(review): inputs, arguments and expectations are the same as in
    // matmul_test14 - confirm whether a different configuration was intended.
    auto x = NDArrayFactory::create<double>('c', {3, 2});
    auto y = NDArrayFactory::create<double>('c', {3, 5});
    auto exp = NDArrayFactory::create<double>('f', {5, 2}, {37., 41.5, 46., 50.5, 55., 46., 52., 58., 64., 70.});

    // Fill the operands with linear sequences.
    x.linspace(1.);
    y.linspace(0.5, 0.5);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {1, 0, 1});

    // Check the op status before touching the result set (same order as matmul_test11/12).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test16) {
    // Rank-4 case: x is 2x2x3x5, y is 2x2x4x3; integer args {1, 1, 1}. Judging
    // by the 2x2x4x5 expected shape, both operands and the result are transposed.
    auto x = NDArrayFactory::create<double>('c', {2, 2, 3, 5});
    auto y = NDArrayFactory::create<double>('c', {2, 2, 4, 3});
    auto exp = NDArrayFactory::create<double>('f', {2, 2, 4, 5}, {4.6, 281.8, 89.2, 582.4, 10., 314.2, 108.1, 628.3, 15.4, 346.6, 127., 674.2, 20.8, 379., 145.9, 720.1, 5.2, 289.6, 93.4, 593.8,
                                                                  11.5, 322.9, 113.2, 640.6, 17.8, 356.2, 133., 687.4, 24.1, 389.5, 152.8, 734.2, 5.8, 297.4, 97.6, 605.2, 13., 331.6, 118.3, 652.9,
                                                                  20.2, 365.8, 139., 700.6, 27.4, 400., 159.7, 748.3, 6.4, 305.2, 101.8, 616.6, 14.5, 340.3, 123.4, 665.2, 22.6, 375.4, 145., 713.8,
                                                                  30.7, 410.5, 166.6, 762.4, 7., 313., 106., 628., 16., 349., 128.5, 677.5, 25., 385., 151., 727., 34., 421., 173.5, 776.5});

    // Fill the operands with linear sequences.
    x.linspace(1.);
    y.linspace(0.1, 0.1);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {1, 1, 1});

    // Check the op status before touching the result set (same order as matmul_test11/12).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test17) {
    // Matrix-by-vector case: x is 4x3, y is a length-4 vector; integer args
    // {1, 0}. Judging by the length-3 expected output, x is used transposed.
    auto x = NDArrayFactory::create<double>('f', {4, 3});
    auto y = NDArrayFactory::create<double>('c', {4});
    auto exp = NDArrayFactory::create<double>('f', {3}, {7., 8., 9.});

    // Fill the operands with linear sequences.
    x.linspace(1.);
    y.linspace(0.1, 0.1);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {1, 0});

    // Check the op status before touching the result set (same order as matmul_test11/12).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test18) {
    // Vector-by-matrix case: x is a length-3 vector, y is 4x3; integer args
    // {0, 1}. Judging by the length-4 expected output, y is used transposed.
    auto x = NDArrayFactory::create<double>('f', {3});
    auto y = NDArrayFactory::create<double>('c', {4, 3});
    auto exp = NDArrayFactory::create<double>('f', {4}, {1.4, 3.2, 5., 6.8});

    // Fill the operands with linear sequences.
    x.linspace(1.);
    y.linspace(0.1, 0.1);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {0, 1});

    // Check the op status before touching the result set (same order as matmul_test11/12).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test19) {
    // Degenerate 1x1 by 1x1 product with no integer args; the expected value is 0.2.
    auto x = NDArrayFactory::create<double>('f', {1, 1});
    auto y = NDArrayFactory::create<double>('c', {1, 1});
    auto exp = NDArrayFactory::create<double>('f', {1, 1}, {0.2});

    // Fill the single-element operands.
    x.linspace(2.);
    y.linspace(0.1, 0.1);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {});

    // Check the op status before touching the result set (same order as matmul_test11/12).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test20) {
    // Degenerate 1x1 by 1x1 product with integer args {1, 1, 1}; the expected value is 0.2.
    auto x = NDArrayFactory::create<double>('f', {1, 1});
    auto y = NDArrayFactory::create<double>('c', {1, 1});
    auto exp = NDArrayFactory::create<double>('f', {1, 1}, {0.2});

    // Fill the single-element operands.
    x.linspace(2.);
    y.linspace(0.1, 0.1);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {1, 1, 1});

    // Check the op status before touching the result set (same order as matmul_test11/12).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test21) {
    // Length-1 vector times a 1x1 matrix with no integer args; the expected
    // output is a length-1 vector holding 0.2.
    auto x = NDArrayFactory::create<double>('f', {1});
    auto y = NDArrayFactory::create<double>('c', {1, 1});
    auto exp = NDArrayFactory::create<double>('f', {1}, {0.2});

    // Fill the single-element operands.
    x.linspace(2.);
    y.linspace(0.1, 0.1);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {});

    // Check the op status before touching the result set (same order as matmul_test11/12).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test22) {
    // 1x1 matrix times a length-1 vector with integer args {1}; the expected
    // output is a length-1 vector holding 0.2.
    auto x = NDArrayFactory::create<double>('f', {1, 1});
    auto y = NDArrayFactory::create<double>('c', {1});
    auto exp = NDArrayFactory::create<double>('f', {1}, {0.2});

    // Fill the single-element operands.
    x.linspace(2.);
    y.linspace(0.1, 0.1);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {1});

    // Check the op status before touching the result set (same order as matmul_test11/12).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test23) {
    // Two length-4 vectors with integer args {1, 1}; the expected output is
    // the scalar 3.
    auto x = NDArrayFactory::create<double>('f', {4});
    auto y = NDArrayFactory::create<double>('c', {4});
    auto exp = NDArrayFactory::create<double>(3.);

    // Fill the operands with linear sequences.
    x.linspace(1.);
    y.linspace(0.1, 0.1);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {1, 1});

    // Check the op status before touching the result set (same order as matmul_test11/12).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test24) {
    // Two length-1 vectors (2 and 3) with integer args {1, 1}; the expected
    // output is the scalar 6.
    auto x = NDArrayFactory::create<double>('f', {1}, {2.});
    auto y = NDArrayFactory::create<double>('c', {1}, {3.});
    auto exp = NDArrayFactory::create<double>(6.);

    nd4j::ops::matmul op;
    auto results = op.execute({&x, &y}, {}, {1, 1});

    // Check the op status before touching the result set (same order as matmul_test11/12).
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete results;
}
TEST_F(DeclarableOpsTests9, test_range_int_1) {
    // Runs the range op with three int scalar inputs 0, 2 and 1.
    // NOTE(review): the inputs are presumably (start, limit, step), which would
    // give the sequence [0, 1] - confirm against the range op definition.
    auto x0 = NDArrayFactory::create<int>(0);
    auto x1 = NDArrayFactory::create<int>(2);
    auto x2 = NDArrayFactory::create<int>(1);

    nd4j::ops::range op;
    auto result = op.execute({&x0, &x1, &x2}, {}, {});
    ASSERT_EQ(Status::OK(), result->status());

    // Assert on the produced values instead of only printing them.
    auto z = result->at(0);
    ASSERT_EQ(2, z->lengthOf());
    ASSERT_EQ(0, z->e<int>(0));
    ASSERT_EQ(1, z->e<int>(1));

    delete result;
}
TEST_F(DeclarableOpsTests9, test_range_empty_1) {
    // Runs the range op with int scalar inputs 0, 0 and 1 and expects an empty output array.
    auto x0 = NDArrayFactory::create<int>(0);
    auto x1 = NDArrayFactory::create<int>(0);
    auto x2 = NDArrayFactory::create<int>(1);

    nd4j::ops::range rangeOp;
    auto outcome = rangeOp.execute({&x0, &x1, &x2}, {}, {});
    ASSERT_EQ(Status::OK(), outcome->status());

    auto produced = outcome->at(0);
    ASSERT_TRUE(produced->isEmpty());

    delete outcome;
}
TEST_F(DeclarableOpsTests9, test_broadcast_bool_1) {
    // Smoke test: a LessThan broadcast of a rank-4 operand over a rank-5 array
    // into a bool target. There are no assertions; the call only has to complete.
    auto x = NDArrayFactory::create<double>('c', {1, 3, 2, 4, 4});
    auto y = NDArrayFactory::create<double>('c', {1, 2, 4, 4});
    auto z = NDArrayFactory::create<bool>('c', {1, 3, 2, 4, 4});

    std::vector<int> broadcastDims = {0, 2, 3, 4};

    x.applyBroadcast(broadcast::LessThan, broadcastDims, &y, &z, nullptr);
}
TEST_F(DeclarableOpsTests9, test_broadcast_bool_2) {
    // Smoke test: same LessThan broadcast as test_broadcast_bool_1, but the
    // second operand is obtained by indexing into a larger 1x7x4x4 array.
    // There are no assertions; the call only has to complete.
    auto orig = NDArrayFactory::create<double>('c', {1, 7, 4, 4});
    std::vector<Nd4jLong> subArrIdx = {0, 0, 0, 2, 0, 0, 0, 0};
    auto y = orig(subArrIdx, true);

    auto x = NDArrayFactory::create<double>('c', {1, 3, 2, 4, 4});
    auto z = NDArrayFactory::create<bool>('c', {1, 3, 2, 4, 4});

    std::vector<int> broadcastDims = {0, 2, 3, 4};

    x.applyBroadcast(broadcast::LessThan, broadcastDims, &y, &z, nullptr);
}
TEST_F(DeclarableOpsTests9, test_unstack_1) {
    // Unstacking a 5x5 array with integer arg 0 must yield five outputs.
    auto x = NDArrayFactory::create<double>('c', {5, 5});
    x.linspace(1.0);

    nd4j::ops::unstack unstackOp;
    auto pieces = unstackOp.execute({&x}, {}, {0});

    ASSERT_EQ(Status::OK(), pieces->status());
    ASSERT_EQ(5, pieces->size());

    delete pieces;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, test_unstack_SGO_1) {
    // Unstacking the vector [1..5] with integer arg 0 must yield five scalars,
    // each matching the corresponding expected scalar in shape and value.
    auto x = NDArrayFactory::create<double>({1, 2, 3, 4, 5});
    x.linspace(1.0);

    // Expected scalar outputs, one per input element.
    auto z1 = NDArrayFactory::create<double>(1);
    auto z2 = NDArrayFactory::create<double>(2);
    auto z3 = NDArrayFactory::create<double>(3);
    auto z4 = NDArrayFactory::create<double>(4);
    auto z5 = NDArrayFactory::create<double>(5);
    std::vector<NDArray*> expected({&z1, &z2, &z3, &z4, &z5});

    nd4j::ops::unstack unstackOp;
    auto pieces = unstackOp.execute({&x}, {}, {0});
    ASSERT_EQ(Status::OK(), pieces->status());
    ASSERT_EQ(5, pieces->size());

    for (size_t idx = 0; idx < pieces->size(); idx++) {
        ASSERT_TRUE(pieces->at(idx)->isSameShape(expected[idx]));
        ASSERT_TRUE(pieces->at(idx)->equalsTo(expected[idx]));
    }

    delete pieces;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, clipbynorm_test12) {
    // Builds an input whose per-column norms lie around the clip value, computes
    // the expected clipping column by column, and compares it with the output
    // of the clipbynorm op executed along the same axis.
    const int bS = 5;
    const int nOut = 4;
    const int axis = 0;
    const double clip = 2.;

    auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412, 0.184, 0.961, 0.897, 0.173, 0.931, 0.736, 0.540, 0.953, 0.278, 0.573, 0.787, 0.320, 0.776, 0.338, 0.311, 0.835, 0.909, 0.890, 0.290});    // uniform random in range [0,1]
    auto colVect = NDArrayFactory::create<double>('c', {bS, 1}, {0.9, 0.95, 1.00, 1.05, 1.1});
    auto expect = NDArrayFactory::create<double>('c', {bS, nOut});

    auto norm2 = x.reduceAlongDims(reduce::Norm2, {axis}, true);    // norm2 has shape [1, nOut]

    // Scale x to norm `clip` along the axis, then perturb it with colVect.
    auto y = ((x / norm2) * clip) * colVect;

    // Reference result: columns whose norm does not exceed `clip` are kept
    // unchanged, the others are rescaled by clip / norm.
    for (int j = 0; j < nOut; ++j) {
        auto yCol = y({0, 0, j, j + 1});
        const double norm2Col = yCol.reduceNumber(reduce::Norm2).e<double>(0);
        if (norm2Col <= clip)
            expect({0, 0, j, j + 1}).assign(yCol);
        else
            expect({0, 0, j, j + 1}).assign(yCol * (clip / norm2Col));
    }

    nd4j::ops::clipbynorm op;
    auto result = op.execute({&y}, {clip}, {axis}, {}, false, nd4j::DataType::DOUBLE);

    // Make sure the op succeeded before reading its output.
    ASSERT_EQ(Status::OK(), result->status());
    auto outFF = result->at(0);

    ASSERT_TRUE(expect.isSameShape(outFF));
    ASSERT_TRUE(expect.equalsTo(outFF));

    delete result;
}
////////////////////////////////////////////////////////////////////////////////
// Gradient check of clipbynorm_bp against clipbynorm on a [bS, nOut] input.
// No axis iArg is passed here (the iArgs list is empty); the sibling tests
// clipbynorm_bp_test2 / clipbynorm_bp_test3 cover explicit axes 0 and 1.
TEST_F(DeclarableOpsTests9, clipbynorm_bp_test1) {
    const int bS = 2;
    const int nOut = 3;
    const double clip = 0.7;

    auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412, 0.184, 0.961, 0.173, 0.736, 0.540});  // uniform random in range [0,1]
    auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});

    // forward op takes {x}; backprop op takes {x, gradO}; both receive clip as the only tArg
    const OpArgsHolder argsHolderFF({&x}, {clip}, {});
    const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {});

    nd4j::ops::clipbynorm opFF;
    nd4j::ops::clipbynorm_bp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

    ASSERT_TRUE(isGradCorrect);
}
////////////////////////////////////////////////////////////////////////////////
// Gradient check of clipbynorm_bp against clipbynorm with axis 0 and clip 0.7.
TEST_F(DeclarableOpsTests9, clipbynorm_bp_test2) {
    const int bS = 2;
    const int nOut = 3;
    const int axis = 0;
    const double clip = 0.7;

    auto input = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412, 0.184, 0.961, 0.173, 0.736, 0.540});  // uniform random in range [0,1]
    auto gradOut = NDArrayFactory::create<double>('c', {bS, nOut});

    // forward: {input}; backward: {input, gradOut}; same tArgs/iArgs for both
    const OpArgsHolder forwardArgs({&input}, {clip}, {axis});
    const OpArgsHolder backwardArgs({&input, &gradOut}, {clip}, {axis});

    nd4j::ops::clipbynorm forwardOp;
    nd4j::ops::clipbynorm_bp backwardOp;

    const bool gradientsMatch = GradCheck::checkGrad(forwardOp, backwardOp, forwardArgs, backwardArgs);

    ASSERT_TRUE(gradientsMatch);
}
////////////////////////////////////////////////////////////////////////////////
// Gradient check of clipbynorm_bp against clipbynorm with axis 1 and clip 1.0.
TEST_F(DeclarableOpsTests9, clipbynorm_bp_test3) {
    const int bS = 2;
    const int nOut = 3;
    const int axis = 1;
    const double clip = 1.;

    auto input = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412, 0.184, 0.961, 0.173, 0.736, 0.540});  // uniform random in range [0,1]
    auto gradOut = NDArrayFactory::create<double>('c', {bS, nOut});

    // forward: {input}; backward: {input, gradOut}; same tArgs/iArgs for both
    const OpArgsHolder forwardArgs({&input}, {clip}, {axis});
    const OpArgsHolder backwardArgs({&input, &gradOut}, {clip}, {axis});

    nd4j::ops::clipbynorm forwardOp;
    nd4j::ops::clipbynorm_bp backwardOp;

    const bool gradientsMatch = GradCheck::checkGrad(forwardOp, backwardOp, forwardArgs, backwardArgs);

    ASSERT_TRUE(gradientsMatch);
}
Merge master to upstream (#7945)
* Shugeo strided slice zeros (#14)
* Modified strided_slice op to properly work with empty-like shapes.
* Fixed test for reduce_mean with empty-like input.
* [WIP] Last merge (#15)
* correct logsoftmax loss (#2)
* Small SameDiff listener fix (#4)
* Various fixes (#6)
* #7839 Fix for asXMatrix and tests
* #7866 EmbeddingSequenceLayer dtype fix + test
* #7856 SameDiff save/load stream methods
* #7859 RegressionEvaluation rank 4 fix + tests + axis configuration
* EvaluationBinary 3d/4d
* More evaluation 3d/4d tests
* #7847 Evaluation empty checks
* Small test fix
* #7848 Fix median edge case
* Improve DL4J samediff layer tests
* [WIP] FastText wrapper implemented (#8)
* FastText implemented
* Some fixes
* Fix shapes for wordsNearest
* Validation of input vectors
* Fixes
* Fixed test
* Thread tagged
* Some tweaks
* setContextClassLoader for DeallocatorServiceThread
* Numpy format tests (#1)
* Various fixes (#11)
* #7852 SameDiff gather fix
* #7892 SameDiff placeholder to constant conversion
* #7890 validate input rank for MLN/CG init methods
* Fix broken permute shape calculation
* Permute and gather fixes
* Tests
* #7850 LogSumExp fix + test
* Handful of test fixes
* Empty arrays with non-scalar shapes (#10)
* minor rearrangements for lambdas
* empty tensors with non-scalar shapes
* numpy empty tensors with non-scalar shapes
* few more empty tweaks
* Small fixes
* conv3d signature update
* micro fix in batchnorm mkldnn
* Import fixes
* Fix
* MKL-DNN update
* Small fill fix
* fill with empty input + test
* Fixes
* Small error improvement
* Fix
* one special test
* couple of fixes for lstm
* Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone
* Fixes
* FP16
* Unsigned
* BFloat16
* Fill op - empty tweaks
* - couple of fixes for empty arrays construction
- stack updated
* strided slice fix
* one transform test
* provide method for reducing shapeInfo in case of input array is empty
* Fixed reduceAlongDimensions to use empty input properly.
* couple of broadcast tests
* couple of tests broadcast tests + tweak to make them pass
* add check of non-empty to methods producing sub-arrays
* Fixed reshapeC with zeros in shape.
* complete empty check in reduce_... legacy ops
* Concat and cumsum/prod
* Tweak to empty shape inference on import
* add empty check to the rest of reduce legacy ops
* one more test
* correct typo in evalReduceShapeInfoEmpty
* Added tests for reduce_* ops to tests with zero shapes.
* few more tests for empty reductions
* Fixed strided_slice op with empty case and tests.
* one more empty reduction test
* Fixed strided_slice test.
* add empty check to NDArray::reshapei
* infOrMax
* empty min/max with infinity tests
* made unstack working correctly with empty arrays
* few IndexReduce tests + tweaks for empty shapes
* add test for empty concat
* few tests fixed
* Validation fix for reductions on empty shapes
* Reverse fix
* Reduction shape calc fixes
* SameDiff.generateOutputVariable: don't use shape function to determine number of outputs
* Range fix
* - NDArray constructor updated for scalars/empty arrays
- few tests fixed
* More fixes
* Empty creator fixes
* concat fix
* concat fix
* TF import tests: allow 'both all NaN' and 'both all inf' to pass
* Slice, zero fraction, and reshape fixes
* transpose, gather
* Zero fraction
* scalar cast fix
* Empty reduction axis support
* few more tests fixed
* Fixed input checks conforming with TF for concat op and tests.
* few tests fixed
* matmul scalar shape fix
* Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats.
* broadcast bool fix
* few more tests
* few more tests
* correct evalReduceShapeInfoEmpty
* argmax/argmin + tests
* one more empty edge case + one more test
* argmax/argmin/realdiv_bp tweaks
* empty reshape test + fix
* Helper fixes
* Small fixes
* Gather test fix
* Gather test fix
* Small fixes
* reduce scalar zero values
* scalar mean workaround
* Remove debug code
* along dim mean workaround
* one more test
* - equalsTo() tweak for empty arrays
- one more test
* broadcast tweaks
* [WIP] Fixing outstanding issues for NLP (#9)
* Avoid using not-inited objects
* Test fixed.
* Redundant method avoided for models like FastText
* KMeans++ implementation
* KMeans++ implementation
* Disable parallel execution
* KMeans++
* Tests
* Dev branch merge (#16)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Fix some issues on master (#17)
* Fix DataVec test issue
* Fix issue with dl4j SameDiff output layer
* Dtype fix for lambda layers
* #7912 BertIterator dtype fix (use float32 not global default)
* [WIP] Next set of CUDA stuff (#7)
New CUDA implementations and improvements
* bad file
* Dev branch master merge (#23)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* SameDiff ops, TF import and fixes (#24)
* CheckNumerics tests + fixes + misc fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fake quant
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* FakeQuantWithMinMaxArgs
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* CheckNumerics fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix libnd4j ALL_INTS and ALL_FLOATS declaration (uint and bfloat types)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Javadoc
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Exception tweak
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix for out of scope stack allocated var use
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignores
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignore for known failing test (already logged issue)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Merge upstream to fork (#25)
* Add thousand-separator commas to TotalParams (#7915)
* Add thousand-separator commas to TotalParams
The number of parameters can be quite large, and it would help the reading of the summary printout to have the TotalParams column & values at the bottom have thousand-separator-commas in them.
* Add thousand-separator commas to MultiLayerNetwork
Corresponding change to MultiLayerNetwork
Signed-off-by: Jxtps Jxtps <jxtps435@gmail.com>
* Update contributing and issue/PR templates (#7934)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix link to AdaDelta paper (#7942)
Fix link to AdaDelta paper hosted on matthewzeiler.com
Signed-off-by: Jxtps
* Fixes, and ignores for known/logged failing issues (#7943)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* SameDiff + DL4J/SameDiff: Multiple fixes (#28)
* #7919 HDF5 attribute buffer length fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7909 Arbiter constructor exception ux improvements
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7925 RNN output layer length checks
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Add listener for validating inputs are not incorrectly modified
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Integrate NonInplaceValidationListener into tests
* #7844 DL4J SameDiff fixes for variable minibatch size
* DL4J SameDiff fixes - ensure gradient for input placeholder is available
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Tweaks to ExternalErrorsFunction - use placeholders, make more robust
* Another fix
* More fixes
* More SameDiff/DL4J fixes
* Scope out scalar array creation in BaseScalarOp
* Remove debug code
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] Final dev branch merge (#29)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* [WIP] Multiple dataset iterators (#27)
* Splitting dataset into arbitrary number
* Fixes
* Multiple split of iterator
* Test
* Test
* Some fixes
* signature change
* one more tweak
Signed-off-by: raver119 <raver119@gmail.com>
* one more test for sequential use of DataSetIteratorSplitter
Signed-off-by: raver119 <raver119@gmail.com>
* Fixes
* Fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* minor test fix
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* couple of assertions tweaked
Signed-off-by: raver119 <raver119@gmail.com>
* MDS splitter test :/
Signed-off-by: raver119 <raver119@gmail.com>
* Minor refactoring
* Multi dataset
* Some fixes
* More tests
* Small number of test fixes/improvements (failures on CI) (#31)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] More CUDA stuff (#26)
* initial commit
Signed-off-by: raver119 <raver119@gmail.com>
* LRN BP CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* less memory
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed bug with crop_and_resize op helper.
* get rid of unnecessary index-calculation function
Signed-off-by: Yurii <yurii@skymind.io>
* Fixed sort with nth_element cuda-based helper.
* Refactored nth_element.
* Refactored nth_element op and tests.
* Modified usage of dim array with sortTad routine.
* Refactored main routine of helper for non_max_image_suppression op.
* non_max_image_suppression op helper with cuda kernel implementation. Initial revision.
* fix vol2col cuda kernel
* meh
Signed-off-by: raver119 <raver119@gmail.com>
* topK concept
Signed-off-by: raver119 <raver119@gmail.com>
* unsorted topK with scanWidth of 1
Signed-off-by: raver119 <raver119@gmail.com>
* correct vol2col tests
* sorted/unsorted topK
Signed-off-by: raver119 <raver119@gmail.com>
* implementation and fixing col2im/col2vol
* Corrected usage flags with input/output with reverse op.
* dup is const now
Signed-off-by: raver119 <raver119@gmail.com>
* percentile op
Signed-off-by: raver119 <raver119@gmail.com>
* group tests for maxpool2d
Signed-off-by: Yurii <yurii@skymind.io>
* special test for george
Signed-off-by: raver119 <raver119@gmail.com>
* less threads for sortTad
Signed-off-by: raver119 <raver119@gmail.com>
* provide conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* remove author in sort tad kernel code
Signed-off-by: Yurii <yurii@skymind.io>
* provide depthwise_conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* - max_pooling_with_argmax
- null check for special use
Signed-off-by: raver119 <raver119@gmail.com>
* dts cuda
Signed-off-by: raver119 <raver119@gmail.com>
* provide sconv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* std cuda
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op to conform TF implementation.
* Improved suppression helper.
* provide pooling3d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* more of minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* (bi)dynamic_rnn
Signed-off-by: raver119 <raver119@gmail.com>
* templates init order
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op.
* Added cuda kernel for non_max_suppression.
* CPU sort by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value tests
Signed-off-by: raver119 <raver119@gmail.com>
* Eliminate compiler error with cuda implementation.
* - repaired gradCheck in cuda
- provide conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* missed signature
Signed-off-by: raver119 <raver119@gmail.com>
* provide depthwise_conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* Implementation of lup helper with cuda kernel. Initial commit.
* further work on backprops for convolutions
Signed-off-by: Yurii <yurii@skymind.io>
* CUDA linear sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* CUDA tad sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* start providing of backprop for pooling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* Added atomicAdd for bool datatype.
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition scalar CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* important comment
Signed-off-by: raver119 <raver119@gmail.com>
* fix pooling2d/3d backprop helpers
Signed-off-by: Yurii <yurii@skymind.io>
* Added non-linear test with dynamic_partition.
* Improved test for dynamic_partition.
* dynamic_partition TAD concept
Signed-off-by: raver119 <raver119@gmail.com>
* - dynamic_partition TAD CUDA impl
- dynamic_partition TAD CPU fix
Signed-off-by: raver119 <raver119@gmail.com>
* - rewrite cpu code for upsampling2d/3d
- write cuda code for upsampling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* dynamic_stitch CUDA vector case
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case impl
Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for dynamic_stitch 3D-4D cases.
* minor tests tweaks
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed type check for dynamic stitch.
* min/max bp
Signed-off-by: raver119 <raver119@gmail.com>
* rewrite code for upsampling2d/3d cpu
Signed-off-by: Yurii <yurii@skymind.io>
* reduce min/max/norm_max bp
Signed-off-by: raver119 <raver119@gmail.com>
* lup implementation. Additional enhancements.
* provide code for upsampling2d/3d backprop
Signed-off-by: Yurii <yurii@skymind.io>
* weightedCrossEntropyWithLogits
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed template math atomicMul for 64bit ints.
* Refactored dynamic_partition_bp op.
* inverseBroadcast fix
Signed-off-by: raver119 <raver119@gmail.com>
* DynamicPartitionBP test datatype fixed.
* - nd4j_atomicMul Windows fix
- cpu/NDArrayLambda.hpp excluded from CUDA
Signed-off-by: raver119 <raver119@gmail.com>
2019-06-27 17:37:04 +02:00
////////////////////////////////////////////////////////////////////////////////
// Forward check of cumprod along axis 1 of a 3x5 input for all four
// combinations of the (exclusive, reverse) iArgs. The axis is supplied as a
// second input array rather than as an iArg.
TEST_F(DeclarableOpsTests9, cumprod_1) {
    auto input = NDArrayFactory::create<double>('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.});
    auto axis = NDArrayFactory::create<Nd4jLong>(1);

    // expected outputs; suffix letters are <exclusive><reverse>, F = 0 and T = 1
    auto expFF = NDArrayFactory::create<double>('c', {3, 5}, {1., 2., 6., 24., 120., 6., 42., 336., 3024., 30240., 11., 132., 1716., 24024., 360360.});
    auto expTF = NDArrayFactory::create<double>('c', {3, 5}, {1, 1, 2, 6, 24, 1, 6, 42, 336, 3024, 1, 11, 132, 1716, 24024});
    auto expFT = NDArrayFactory::create<double>('c', {3, 5}, {120, 120, 60, 20, 5, 30240, 5040, 720, 90, 10, 360360, 32760, 2730, 210, 15});
    auto expTT = NDArrayFactory::create<double>('c', {3, 5}, {120, 60, 20, 5, 1, 5040, 720, 90, 10, 1, 32760, 2730, 210, 15, 1});

    struct FlagCase {
        int exclusive;
        int reverse;
        NDArray* expected;
    };

    // same order as the flags were exercised historically: FF, TF, FT, TT
    const FlagCase flagCases[] = {
        {0, 0, &expFF},
        {1, 0, &expTF},
        {0, 1, &expFT},
        {1, 1, &expTT},
    };

    nd4j::ops::cumprod op;

    for (const auto& current : flagCases) {
        auto result = op.execute({&input, &axis}, {}, {current.exclusive, current.reverse}, {}, false, nd4j::DataType::DOUBLE);
        ASSERT_EQ(Status::OK(), result->status());

        auto z = result->at(0);
        ASSERT_TRUE(current.expected->equalsTo(z));

        delete result;
    }
}
////////////////////////////////////////////////////////////////////////////////
// Forward check of cumprod on a 2x1500 FLOAT32 array with iArgs {0, 0, 1}.
// Both rows are filled with the same linspace(1, 0.1) sequence, and the
// expected array is built with a scalar running product.
TEST_F(DeclarableOpsTests9, cumprod_2) {
    NDArray input('c', {2, 1500}, nd4j::DataType::FLOAT32);
    NDArray inRow0 = input(0, {0});
    NDArray inRow1 = input(1, {0});
    inRow0.linspace(1, 0.1);
    inRow1.linspace(1, 0.1);

    // NOTE(review): the row sub-arrays are presumably views, so writes through
    // them land in `input` / `expected` - the final comparison relies on that.
    NDArray expected('c', {2, 1500}, nd4j::DataType::FLOAT32);
    NDArray expRow0 = expected(0, {0});
    NDArray expRow1 = expected(1, {0});
    expRow0.p<float>(0, 1.);
    expRow1.p<float>(0, 1.);

    for (int pos = 1; pos < 1500; ++pos) {
        // the running product is read from row 0 for both rows; the two input
        // rows hold identical data, so the two expected rows stay identical
        const auto runningProduct = expRow0.e<float>(pos - 1);
        expRow0.p<float>(pos, runningProduct * inRow0.e<float>(pos));
        expRow1.p<float>(pos, runningProduct * inRow1.e<float>(pos));
    }

    nd4j::ops::cumprod cumprodOp;
    auto outputs = cumprodOp.execute({&input}, {}, {0, 0, 1});
    ASSERT_EQ(Status::OK(), outputs->status());

    auto actual = outputs->at(0);
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outputs;
}
2019-06-06 14:21:15 +02:00
//////////////////////////////////////////////////////////////////////
// Gradient check of cumprod_bp against cumprod on a 4x4 input, iArgs {0, 0}.
TEST_F(DeclarableOpsTests9, cumprod_bp_check_1) {
    auto input = NDArrayFactory::create<double>('c', {4, 4});
    input.linspace(1);
    auto gradOut = NDArrayFactory::create<double>('c', {4, 4});

    // other cumprod tests in this file pass these two iArg slots as {exclusive, reverse}
    const OpArgsHolder forwardArgs({&input}, {}, {0, 0});
    const OpArgsHolder backwardArgs({&input, &gradOut}, {}, {0, 0});

    nd4j::ops::cumprod forwardOp;
    nd4j::ops::cumprod_bp backwardOp;

    const bool gradientsMatch = GradCheck::checkGrad(forwardOp, backwardOp, forwardArgs, backwardArgs, {1, 1}, {1, 1}, GradCheck::MEAN);
    ASSERT_TRUE(gradientsMatch);
}
//////////////////////////////////////////////////////////////////////
// Gradient check of cumprod_bp against cumprod on a 4x4 input, iArgs {1, 1}.
TEST_F(DeclarableOpsTests9, cumprod_bp_check_2) {
    auto input = NDArrayFactory::create<double>('c', {4, 4});
    input.linspace(1);
    auto gradOut = NDArrayFactory::create<double>('c', {4, 4});

    // other cumprod tests in this file pass these two iArg slots as {exclusive, reverse}
    const OpArgsHolder forwardArgs({&input}, {}, {1, 1});
    const OpArgsHolder backwardArgs({&input, &gradOut}, {}, {1, 1});

    nd4j::ops::cumprod forwardOp;
    nd4j::ops::cumprod_bp backwardOp;

    const bool gradientsMatch = GradCheck::checkGrad(forwardOp, backwardOp, forwardArgs, backwardArgs, {1, 1}, {1, 1}, GradCheck::MEAN);
    ASSERT_TRUE(gradientsMatch);
}
//////////////////////////////////////////////////////////////////////
// Gradient check of cumprod_bp against cumprod on a 4x4 input, iArgs {1, 0}.
TEST_F(DeclarableOpsTests9, cumprod_bp_check_3) {
    auto input = NDArrayFactory::create<double>('c', {4, 4});
    input.linspace(1);
    auto gradOut = NDArrayFactory::create<double>('c', {4, 4});

    // other cumprod tests in this file pass these two iArg slots as {exclusive, reverse}
    const OpArgsHolder forwardArgs({&input}, {}, {1, 0});
    const OpArgsHolder backwardArgs({&input, &gradOut}, {}, {1, 0});

    nd4j::ops::cumprod forwardOp;
    nd4j::ops::cumprod_bp backwardOp;

    const bool gradientsMatch = GradCheck::checkGrad(forwardOp, backwardOp, forwardArgs, backwardArgs, {1, 1}, {1, 1}, GradCheck::MEAN);
    ASSERT_TRUE(gradientsMatch);
}
//////////////////////////////////////////////////////////////////////
// Gradient check of cumprod_bp against cumprod on a 4x4 input, iArgs {0, 1}.
TEST_F(DeclarableOpsTests9, cumprod_bp_check_4) {
    auto input = NDArrayFactory::create<double>('c', {4, 4});
    input.linspace(1);
    auto gradOut = NDArrayFactory::create<double>('c', {4, 4});

    // other cumprod tests in this file pass these two iArg slots as {exclusive, reverse}
    const OpArgsHolder forwardArgs({&input}, {}, {0, 1});
    const OpArgsHolder backwardArgs({&input, &gradOut}, {}, {0, 1});

    nd4j::ops::cumprod forwardOp;
    nd4j::ops::cumprod_bp backwardOp;

    const bool gradientsMatch = GradCheck::checkGrad(forwardOp, backwardOp, forwardArgs, backwardArgs, {1, 1}, {1, 1}, GradCheck::MEAN);
    ASSERT_TRUE(gradientsMatch);
}
//////////////////////////////////////////////////////////////////////
// Gradient check of cumsum_bp against cumsum on a 4x4 input, iArgs {1, 1}.
TEST_F(DeclarableOpsTests9, cumsum_bp_check_2) {
    auto input = NDArrayFactory::create<double>('c', {4, 4});
    input.linspace(1);
    auto gradOut = NDArrayFactory::create<double>('c', {4, 4});

    // NOTE(review): presumably {exclusive, reverse}, mirroring the cumprod tests - confirm against the op
    const OpArgsHolder forwardArgs({&input}, {}, {1, 1});
    const OpArgsHolder backwardArgs({&input, &gradOut}, {}, {1, 1});

    nd4j::ops::cumsum forwardOp;
    nd4j::ops::cumsum_bp backwardOp;

    const bool gradientsMatch = GradCheck::checkGrad(forwardOp, backwardOp, forwardArgs, backwardArgs, {1, 1}, {1, 1}, GradCheck::MEAN);
    ASSERT_TRUE(gradientsMatch);
}
////////////////////////////////////////////////////////////////////////////////
// Gradient check of cumprod_bp against cumprod on a 3x5 input, with the axis
// supplied as a second input array and exclusive = reverse = 0.
//
// Forward outputs for every (exclusive, reverse) combination are asserted in
// cumprod_1, so this test only runs the gradient check.
TEST_F(DeclarableOpsTests9, cumprod_test1) {
    auto inputC = NDArrayFactory::create<double>('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.});
    auto axis = NDArrayFactory::create<double>(1.);
    auto gradO = NDArrayFactory::create<double>('c', {3, 5});

    const int exclusive = 0;
    const int reverse = 0;

    // forward op takes {input, axis}; backprop op takes {input, axis, gradO}
    const OpArgsHolder argsHolderFF({&inputC, &axis}, {}, {exclusive, reverse});
    const OpArgsHolder argsHolderBP({&inputC, &axis, &gradO}, {}, {exclusive, reverse});

    nd4j::ops::cumprod opFF;
    nd4j::ops::cumprod_bp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1}, GradCheck::MEAN);

    ASSERT_TRUE(isGradCorrect);
}
////////////////////////////////////////////////////////////////////////////////
// Gradient check of cumprod_bp against cumprod on a 2x2 linspace input, with
// the axis supplied as a second input array and exclusive = reverse = 0.
TEST_F(DeclarableOpsTests9, cumprod_test2) {
    auto input = NDArrayFactory::create<double>('c', {2, 2});
    input.linspace(1);
    auto axisArr = NDArrayFactory::create<double>(1.);
    auto gradOut = NDArrayFactory::create<double>('c', {2, 2});

    int exclusive = 0;
    int reverse = 0;

    // forward op takes {input, axis}; backprop op takes {input, axis, gradOut}
    const OpArgsHolder forwardArgs({&input, &axisArr}, {}, {exclusive, reverse});
    const OpArgsHolder backwardArgs({&input, &axisArr, &gradOut}, {}, {exclusive, reverse});

    nd4j::ops::cumprod forwardOp;
    nd4j::ops::cumprod_bp backwardOp;

    // NOTE(review): four check flags are passed although only two forward inputs exist - confirm this is intended
    const bool gradientsMatch = GradCheck::checkGrad(forwardOp, backwardOp, forwardArgs, backwardArgs, {1, 1, 1, 1}, {1, 1}, GradCheck::MEAN);
    ASSERT_TRUE(gradientsMatch);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test1) {
    // prelu forward: {2,3,4} input with a {3,4} alpha and no integer arguments.
    auto input = NDArrayFactory::create<double>('c', {2, 3, 4},
        {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f,
         0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
    auto alpha = NDArrayFactory::create<double>('c', {3, 4},
        {-0.6f, -0.5f, -0.4f, -0.3f,
         -0.2f, -0.1f, 0.f, 0.1f,
         0.2f, 0.3f, 0.4f, 0.5f});
    auto expected = NDArrayFactory::create<double>('c', {2, 3, 4},
        {7.2f, 5.5f, 4.f, 2.7f, 1.6f, 0.7f, 0.f, -0.5f, -0.8f, -0.9f, -0.8f, -0.5f,
         0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});

    nd4j::ops::prelu preluOp;
    auto results = preluOp.execute({&input, &alpha}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test2) {
    // prelu forward: {2,3,4} input with a {3} alpha and integer arguments {0}.
    auto input = NDArrayFactory::create<double>('c', {2, 3, 4},
        {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f,
         0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
    auto alpha = NDArrayFactory::create<double>('c', {3}, {-0.6f, 2.f, 4.f});
    auto expected = NDArrayFactory::create<double>('c', {2, 3, 4},
        {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f,
         0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});

    nd4j::ops::prelu preluOp;
    auto results = preluOp.execute({&input, &alpha}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test3) {
    // prelu forward: {2,3,4} input with a {3,1} alpha and integer arguments {0}.
    auto input = NDArrayFactory::create<double>('c', {2, 3, 4},
        {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f,
         0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
    auto alpha = NDArrayFactory::create<double>('c', {3, 1}, {-0.6f, 2.f, 4.f});
    auto expected = NDArrayFactory::create<double>('c', {2, 3, 4},
        {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f,
         0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});

    nd4j::ops::prelu preluOp;
    auto results = preluOp.execute({&input, &alpha}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test4) {
    // prelu forward: {2,3,4} input with a {1,3} alpha and integer arguments {0}.
    auto input = NDArrayFactory::create<double>('c', {2, 3, 4},
        {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f,
         0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
    auto alpha = NDArrayFactory::create<double>('c', {1, 3}, {-0.6f, 2.f, 4.f});
    auto expected = NDArrayFactory::create<double>('c', {2, 3, 4},
        {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f,
         0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});

    nd4j::ops::prelu preluOp;
    auto results = preluOp.execute({&input, &alpha}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test5) {
    // prelu forward: {2,3,4} input with a {4} alpha and integer arguments {1}.
    auto input = NDArrayFactory::create<double>('c', {2, 3, 4},
        {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f,
         0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
    auto alpha = NDArrayFactory::create<double>('c', {4}, {-0.6f, 2.f, 4.f, -1.f});
    auto expected = NDArrayFactory::create<double>('c', {2, 3, 4},
        {7.2f, -22.f, -40.f, 9.f, 4.8f, -14.f, -24.f, 5.f, 2.4f, -6.f, -8.f, 1.f,
         0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});

    nd4j::ops::prelu preluOp;
    auto results = preluOp.execute({&input, &alpha}, {}, {1}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test6) {
    // prelu forward: {2,3,4} input with a single-element {1,1,1} alpha and
    // integer arguments {1, 0}.
    auto input = NDArrayFactory::create<double>('c', {2, 3, 4},
        {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f,
         0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
    auto alpha = NDArrayFactory::create<double>('c', {1, 1, 1}, {-2.});
    auto expected = NDArrayFactory::create<double>('c', {2, 3, 4},
        {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f,
         0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});

    nd4j::ops::prelu preluOp;
    auto results = preluOp.execute({&input, &alpha}, {}, {1, 0}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test7) {
    // prelu forward: {2,3,4} input with a scalar alpha and integer arguments {1, 0}.
    auto input = NDArrayFactory::create<double>('c', {2, 3, 4},
        {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f,
         0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
    auto alpha = NDArrayFactory::create<double>(-2.f);
    auto expected = NDArrayFactory::create<double>('c', {2, 3, 4},
        {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f,
         0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});

    nd4j::ops::prelu preluOp;
    auto results = preluOp.execute({&input, &alpha}, {}, {1, 0}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test8) {
    // prelu forward: {2,3,4} input with a scalar alpha and the repeated
    // integer arguments {1, 0, 1, 0, 1, 0}.
    auto input = NDArrayFactory::create<double>('c', {2, 3, 4},
        {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f,
         0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
    auto alpha = NDArrayFactory::create<double>(-2.f);
    auto expected = NDArrayFactory::create<double>('c', {2, 3, 4},
        {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f,
         0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});

    nd4j::ops::prelu preluOp;
    auto results = preluOp.execute({&input, &alpha}, {}, {1, 0, 1, 0, 1, 0}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test9) {
    // prelu forward: {2,4} input with a scalar alpha and integer arguments {0}.
    auto input = NDArrayFactory::create<double>('c', {2, 4},
        {-4.f, -3.f, -2.f, -1.f,
         0.f, 1.f, 2.f, 3.f});
    auto alpha = NDArrayFactory::create<double>(-2.f);
    auto expected = NDArrayFactory::create<double>('c', {2, 4},
        {8.f, 6.f, 4.f, 2.f,
         0.f, 1.f, 2.f, 3.f});

    nd4j::ops::prelu preluOp;
    auto results = preluOp.execute({&input, &alpha}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test10) {
    // prelu forward: {2,4} input with a scalar alpha and integer arguments {1}.
    auto input = NDArrayFactory::create<double>('c', {2, 4},
        {-4.f, -3.f, -2.f, -1.f,
         0.f, 1.f, 2.f, 3.f});
    auto alpha = NDArrayFactory::create<double>(-2.f);
    auto expected = NDArrayFactory::create<double>('c', {2, 4},
        {8.f, 6.f, 4.f, 2.f,
         0.f, 1.f, 2.f, 3.f});

    nd4j::ops::prelu preluOp;
    auto results = preluOp.execute({&input, &alpha}, {}, {1}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test11) {
    // prelu forward: {2,3,4,5} input filled by linspace(-50.) with a {4} alpha
    // and integer arguments {1, 3}.
    auto input = NDArrayFactory::create<double>('c', {2, 3, 4, 5});
    input.linspace(-50.);
    auto alpha = NDArrayFactory::create<double>('c', {4}, {0.f, -0.5f, 0.5f, -1.f});
    auto expected = NDArrayFactory::create<double>('c', {2, 3, 4, 5},
        {0.f, 0.f, 0.f, 0.f, 0.f,
         22.5f, 22.f, 21.5f, 21.f, 20.5f,
         -20.f, -19.5f, -19.f, -18.5f, -18.f,
         35.f, 34.f, 33.f, 32.f, 31.f,
         0.f, 0.f, 0.f, 0.f, 0.f,
         12.5f, 12.f, 11.5f, 11.f, 10.5f,
         -10.f, -9.5f, -9.f, -8.5f, -8.f,
         15.f, 14.f, 13.f, 12.f, 11.f,
         0.f, 0.f, 0.f, 0.f, 0.f,
         2.5f, 2.f, 1.5f, 1.f, 0.5f,
         0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f,
         10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f,
         20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f,
         30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f,
         40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f,
         50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f,
         60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f});

    nd4j::ops::prelu preluOp;
    auto results = preluOp.execute({&input, &alpha}, {}, {1, 3}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test12) {
    // prelu forward: {2,3,4,5} input filled by linspace(-50.) with a {3,5}
    // alpha and integer arguments {-1, 2}.
    auto input = NDArrayFactory::create<double>('c', {2, 3, 4, 5});
    input.linspace(-50.);
    auto alpha = NDArrayFactory::create<double>('c', {3, 5},
        {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f,
         -0.2f, -0.1f, 0.f, 0.1f, 0.2f,
         0.3f, 0.4f, 0.5f, 0.6f, 0.7f});
    auto expected = NDArrayFactory::create<double>('c', {2, 3, 4, 5},
        {35.f, 29.4f, 24.f, 18.8f, 13.8f,
         31.5f, 26.4f, 21.5f, 16.8f, 12.3f,
         28.f, 23.4f, 19.f, 14.8f, 10.8f,
         24.5f, 20.4f, 16.5f, 12.8f, 9.3f,
         6.f, 2.9f, 0.f, -2.7f, -5.2f,
         5.f, 2.4f, 0.f, -2.2f, -4.2f,
         4.f, 1.9f, 0.f, -1.7f, -3.2f,
         3.f, 1.4f, 0.f, -1.2f, -2.2f,
         -3.f, -3.6f, -4.f, -4.2f, -4.2f,
         -1.5f, -1.6f, -1.5f, -1.2f, -0.7f,
         0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f,
         10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f,
         20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f,
         30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f,
         40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f,
         50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f,
         60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f});

    nd4j::ops::prelu preluOp;
    auto results = preluOp.execute({&input, &alpha}, {}, {-1, 2}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test13) {
    // prelu forward: {2,3,4,5} input filled by linspace(-50.) with a {5,3}
    // alpha and integer arguments {-1, 2}.
    auto input = NDArrayFactory::create<double>('c', {2, 3, 4, 5});
    input.linspace(-50.);
    auto alpha = NDArrayFactory::create<double>('c', {5, 3},
        {-0.7f, -0.6f, -0.5f,
         -0.4f, -0.3f, -0.2f,
         -0.1f, 0.f, 0.1f,
         0.2f, 0.3f, 0.4f,
         0.5f, 0.6f, 0.7f});
    auto expected = NDArrayFactory::create<double>('c', {2, 3, 4, 5},
        {35.f, 29.4f, 24.f, 18.8f, 13.8f,
         31.5f, 26.4f, 21.5f, 16.8f, 12.3f,
         28.f, 23.4f, 19.f, 14.8f, 10.8f,
         24.5f, 20.4f, 16.5f, 12.8f, 9.3f,
         6.f, 2.9f, 0.f, -2.7f, -5.2f,
         5.f, 2.4f, 0.f, -2.2f, -4.2f,
         4.f, 1.9f, 0.f, -1.7f, -3.2f,
         3.f, 1.4f, 0.f, -1.2f, -2.2f,
         -3.f, -3.6f, -4.f, -4.2f, -4.2f,
         -1.5f, -1.6f, -1.5f, -1.2f, -0.7f,
         0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f,
         10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f,
         20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f,
         30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f,
         40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f,
         50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f,
         60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f});

    nd4j::ops::prelu preluOp;
    auto results = preluOp.execute({&input, &alpha}, {}, {-1, 2}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test14) {
    // prelu forward: {2,3,4,5} input filled by linspace(-50.) with a {2,10}
    // alpha and integer arguments {-2}.
    auto input = NDArrayFactory::create<double>('c', {2, 3, 4, 5});
    input.linspace(-50.);
    auto alpha = NDArrayFactory::create<double>('c', {2, 10},
        {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f,
         0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f});
    auto expected = NDArrayFactory::create<double>('c', {2, 3, 4, 5},
        {35.f, 29.4f, 24.f, 18.8f, 13.8f, 9.f, 4.4f, 0.f, -4.2f, -8.2f,
         -12.f, -15.6f, -19.f, -22.2f, -25.2f, -28.f, -30.6f, -33.f, -35.2f, -37.2f,
         21.f, 17.4f, 14.f, 10.8f, 7.8f, 5.f, 2.4f, 0.f, -2.2f, -4.2f,
         -6.f, -7.6f, -9.f, -10.2f, -11.2f, -12.f, -12.6f, -13.f, -13.2f, -13.2f,
         7.f, 5.4f, 4.f, 2.8f, 1.8f, 1.f, 0.4f, 0.f, -0.2f, -0.2f,
         0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f,
         10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f,
         20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f,
         30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f,
         40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f,
         50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f,
         60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f});

    nd4j::ops::prelu preluOp;
    auto results = preluOp.execute({&input, &alpha}, {}, {-2}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) {
    // thresholdedrelu forward with theta = 2: per the expected table, only the
    // inputs 3..11 are kept and every other element becomes 0.
    const float theta = 2.f;

    auto input = NDArrayFactory::create<double>('c', {2, 3, 4},
        {-12.f, -11.f, -10.f, -9.f,
         -8.f, -7.f, -6.f, -5.f,
         -4.f, -3.f, -2.f, -1.f,
         0.f, 1.f, 2.f, 3.f,
         4.f, 5.f, 6.f, 7.f,
         8.f, 9.f, 10.f, 11.f});
    auto expected = NDArrayFactory::create<double>('c', {2, 3, 4},
        {0.f, 0.f, 0.f, 0.f,
         0.f, 0.f, 0.f, 0.f,
         0.f, 0.f, 0.f, 0.f,
         0.f, 0.f, 0.f, 3.f,
         4.f, 5.f, 6.f, 7.f,
         8.f, 9.f, 10.f, 11.f});

    nd4j::ops::thresholdedrelu reluOp;
    auto results = reluOp.execute({&input}, {theta}, {}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
2019-06-15 13:34:34 +02:00
2019-06-06 14:21:15 +02:00
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {
    // thresholdedrelu forward with a negative theta (-2) on unordered input.
    const float theta = -2.f;

    auto input = NDArrayFactory::create<double>('c', {2, 3, 4},
        {0.f, -4.f, -10.f, -8.f,
         0.f, -9.f, -8.f, 5.f,
         6.f, 6.f, 9.f, 6.f,
         -8.f, 5.f, 10.f, -2.f,
         3.f, -7.f, 4.f, -8.f,
         -4.f, -9.f, -9.f, 3.f});
    auto expected = NDArrayFactory::create<double>('c', {2, 3, 4},
        {0.f, 0.f, 0.f, 0.f,
         0.f, 0.f, 0.f, 5.f,
         6.f, 6.f, 9.f, 6.f,
         0.f, 5.f, 10.f, 0.f,
         3.f, 0.f, 4.f, 0.f,
         0.f, 0.f, 0.f, 3.f});

    nd4j::ops::thresholdedrelu reluOp;
    auto results = reluOp.execute({&input}, {theta}, {}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_bp_test1) {
    // Gradient check of prelu against prelu_bp: {2,3,4} input, {3,4} alpha,
    // no integer arguments.
    auto input = NDArrayFactory::create<double>('c', {2, 3, 4},
        {-12., -11., -10., -9., -8., -7., -6., -5., -4., -3., -2., -1.,
         0.5, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.});
    auto alpha = NDArrayFactory::create<double>('c', {3, 4},
        {-0.6, -0.5, -0.4, -0.3,
         -0.2, -0.1, 0.5, 0.1,
         0.2, 0.3, 0.4, 0.5});
    auto outputGrad = NDArrayFactory::create<double>('c', {2, 3, 4});

    nd4j::ops::prelu forwardOp;
    nd4j::ops::prelu_bp backwardOp;
    const OpArgsHolder forwardArgs({&input, &alpha}, {}, {});
    const OpArgsHolder backwardArgs({&input, &alpha, &outputGrad}, {}, {});

    const bool gradientsMatch = GradCheck::checkGrad(forwardOp, backwardOp, forwardArgs, backwardArgs);
    ASSERT_TRUE(gradientsMatch);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_bp_test2) {
    // Gradient check of prelu against prelu_bp: {2,3,4} input, {4} alpha,
    // integer arguments {1}.
    auto input = NDArrayFactory::create<double>('c', {2, 3, 4},
        {-12., -11., -10., -9., -8., -7., -6., -5., -4., -3., -2., -1.,
         0.5, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.});
    auto alpha = NDArrayFactory::create<double>('c', {4}, {-0.6, 2., 4., -1.});
    auto outputGrad = NDArrayFactory::create<double>('c', {2, 3, 4});

    nd4j::ops::prelu forwardOp;
    nd4j::ops::prelu_bp backwardOp;
    const OpArgsHolder forwardArgs({&input, &alpha}, {}, {1});
    const OpArgsHolder backwardArgs({&input, &alpha, &outputGrad}, {}, {1});

    const bool gradientsMatch = GradCheck::checkGrad(forwardOp, backwardOp, forwardArgs, backwardArgs);
    ASSERT_TRUE(gradientsMatch);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_bp_test3) {
    // Gradient check of prelu against prelu_bp: {2,3,2,5} input filled by
    // linspace(-30.), {5,3} alpha, integer arguments {-1, 2}.
    auto input = NDArrayFactory::create<double>('c', {2, 3, 2, 5});
    input.linspace(-30.);
    // Avoid zero, since it is a point of discontinuity for prelu.
    input.p(30, 0.5);

    auto alpha = NDArrayFactory::create<double>('c', {5, 3},
        {-0.7, -0.6, -0.5,
         -0.4, -0.3, -0.2,
         -0.1, 0.5, 0.1,
         0.2, 0.3, 0.4,
         0.5, 0.6, 0.7});
    auto outputGrad = NDArrayFactory::create<double>('c', {2, 3, 2, 5});

    nd4j::ops::prelu forwardOp;
    nd4j::ops::prelu_bp backwardOp;
    const OpArgsHolder forwardArgs({&input, &alpha}, {}, {-1, 2});
    const OpArgsHolder backwardArgs({&input, &alpha, &outputGrad}, {}, {-1, 2});

    const bool gradientsMatch = GradCheck::checkGrad(forwardOp, backwardOp, forwardArgs, backwardArgs);
    ASSERT_TRUE(gradientsMatch);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_bp_test4) {
    // Gradient check of prelu against prelu_bp: {2,3,4,5} input filled by
    // linspace(-50.), {2,10} alpha, integer arguments {-2}.
    auto input = NDArrayFactory::create<double>('c', {2, 3, 4, 5});
    input.linspace(-50.);
    // Avoid zero, since it is a point of discontinuity for prelu.
    input.p(50, 0.5);

    auto alpha = NDArrayFactory::create<double>('c', {2, 10},
        {-0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.25, 0.1, 0.2,
         0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2});
    auto outputGrad = NDArrayFactory::create<double>('c', {2, 3, 4, 5});

    nd4j::ops::prelu forwardOp;
    nd4j::ops::prelu_bp backwardOp;
    const OpArgsHolder forwardArgs({&input, &alpha}, {}, {-2});
    const OpArgsHolder backwardArgs({&input, &alpha, &outputGrad}, {}, {-2});

    const bool gradientsMatch = GradCheck::checkGrad(forwardOp, backwardOp, forwardArgs, backwardArgs);
    ASSERT_TRUE(gradientsMatch);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, thresholdedrelu_bp_test1) {
    // Gradient check of thresholdedrelu against thresholdedrelu_bp with
    // theta = 0.15 passed as the single floating-point argument.
    const double theta = 0.15;

    auto input = NDArrayFactory::create<double>('c', {2, 3, 4},
        {1.2, 1.1, 1., 0.9,
         0.8, -0.7, -0.6, -0.5,
         -0.4, -0.3, -0.2, -0.1,
         0., 0.1, 0.2, 0.3,
         0.4, 0.5, 0.6, 0.7,
         0.8, -0.9, -1.0, -1.1});
    auto outputGrad = NDArrayFactory::create<double>('c', {2, 3, 4});

    nd4j::ops::thresholdedrelu forwardOp;
    nd4j::ops::thresholdedrelu_bp backwardOp;
    const OpArgsHolder forwardArgs({&input}, {theta}, {});
    const OpArgsHolder backwardArgs({&input, &outputGrad}, {theta}, {});

    const bool gradientsMatch = GradCheck::checkGrad(forwardOp, backwardOp, forwardArgs, backwardArgs);
    ASSERT_TRUE(gradientsMatch);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_test1) {
    // multiply forward: {2,3,4} left operand with a {4} right operand.
    auto lhs = NDArrayFactory::create<double>('c', {2, 3, 4});
    lhs.linspace(1.f);
    auto rhs = NDArrayFactory::create<double>('c', {4});
    rhs.linspace(0.1f, 0.1f);

    auto expected = NDArrayFactory::create<double>('c', {2, 3, 4},
        {0.1f, 0.4f, 0.9f, 1.6f,
         0.5f, 1.2f, 2.1f, 3.2f,
         0.9f, 2.f, 3.3f, 4.8f,
         1.3f, 2.8f, 4.5f, 6.4f,
         1.7f, 3.6f, 5.7f, 8.f,
         2.1f, 4.4f, 6.9f, 9.6f});

    nd4j::ops::multiply multiplyOp;
    auto results = multiplyOp.execute({&lhs, &rhs}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto product = results->at(0);
    ASSERT_TRUE(expected.isSameShape(product));
    ASSERT_TRUE(expected.equalsTo(product));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_test2) {
    // multiply forward with the scalar passed as the FIRST operand and the
    // {2,3,4} array as the second.
    auto array = NDArrayFactory::create<double>('c', {2, 3, 4});
    array.linspace(1.f);
    auto scalar = NDArrayFactory::create<double>(0.1);

    auto expected = NDArrayFactory::create<double>('c', {2, 3, 4},
        {0.1f, 0.2f, 0.3f, 0.4f,
         0.5f, 0.6f, 0.7f, 0.8f,
         0.9f, 1.f, 1.1f, 1.2f,
         1.3f, 1.4f, 1.5f, 1.6f,
         1.7f, 1.8f, 1.9f, 2.f,
         2.1f, 2.2f, 2.3f, 2.4f});

    nd4j::ops::multiply multiplyOp;
    auto results = multiplyOp.execute({&scalar, &array}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto product = results->at(0);
    ASSERT_TRUE(expected.isSameShape(product));
    ASSERT_TRUE(expected.equalsTo(product));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_test3) {
    // multiply forward: {2,1,4} left operand with a {3,1} right operand; the
    // expected result has shape {2,3,4}.
    auto lhs = NDArrayFactory::create<double>('c', {2, 1, 4});
    lhs.linspace(1.f);
    auto rhs = NDArrayFactory::create<double>('c', {3, 1});
    rhs.linspace(0.1f, 0.1f);

    auto expected = NDArrayFactory::create<double>('c', {2, 3, 4},
        {0.1f, 0.2f, 0.3f, 0.4f,
         0.2f, 0.4f, 0.6f, 0.8f,
         0.3f, 0.6f, 0.9f, 1.2f,
         0.5f, 0.6f, 0.7f, 0.8f,
         1.f, 1.2f, 1.4f, 1.6f,
         1.5f, 1.8f, 2.1f, 2.4f});

    nd4j::ops::multiply multiplyOp;
    auto results = multiplyOp.execute({&lhs, &rhs}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto product = results->at(0);
    ASSERT_TRUE(expected.isSameShape(product));
    ASSERT_TRUE(expected.equalsTo(product));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_test4) {
    // multiply forward: {1,1} left operand with a scalar right operand; the
    // expected result keeps the {1,1} shape.
    auto lhs = NDArrayFactory::create<double>('c', {1, 1});
    lhs.linspace(1.f);
    auto rhs = NDArrayFactory::create<double>(0.1f);
    auto expected = NDArrayFactory::create<double>('c', {1, 1}, {0.1f});

    nd4j::ops::multiply multiplyOp;
    auto results = multiplyOp.execute({&lhs, &rhs}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto product = results->at(0);
    ASSERT_TRUE(expected.isSameShape(product));
    ASSERT_TRUE(expected.equalsTo(product));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_test5) {
    // multiply forward on two scalars.
    auto lhs = NDArrayFactory::create<double>(1.f);
    auto rhs = NDArrayFactory::create<double>(0.1f);
    auto expected = NDArrayFactory::create<double>(0.1f);

    nd4j::ops::multiply multiplyOp;
    auto results = multiplyOp.execute({&lhs, &rhs}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto product = results->at(0);
    ASSERT_TRUE(expected.isSameShape(product));
    ASSERT_TRUE(expected.equalsTo(product));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_bp_test1) {
    // Gradient check of multiply against multiply_bp for a {1,1} x and a
    // scalar y. Both ops are also executed once directly and their statuses
    // are verified, so those direct runs are no longer silently ignored.
    auto x = NDArrayFactory::create<double>('c', {1, 1}, {100.});
    auto y = NDArrayFactory::create<double>(0.1);
    auto dLdz = NDArrayFactory::create<double>('c', {1, 1});

    const OpArgsHolder argsHolderFF({&x, &y}, {}, {});
    const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {});

    nd4j::ops::multiply opFF;
    nd4j::ops::multiply_bp opBP;

    auto resFF = opFF.execute({&x, &y}, {}, {});
    auto resBP = opBP.execute({&x, &y, &dLdz}, {}, {});

    // Capture the statuses and free the result sets BEFORE asserting: a
    // failing ASSERT_* returns from the test body immediately and would
    // otherwise leak them.
    const auto statusFF = resFF->status();
    const auto statusBP = resBP->status();
    delete resFF;
    delete resBP;

    ASSERT_EQ(ND4J_STATUS_OK, statusFF);
    ASSERT_EQ(ND4J_STATUS_OK, statusBP);

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
    ASSERT_TRUE(isGradCorrect);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_bp_test2) {
    // Gradient check of multiply against multiply_bp: {2,2} x, scalar y.
    auto lhs = NDArrayFactory::create<double>('c', {2, 2}, {1., 2., 3., 4.});
    auto rhs = NDArrayFactory::create<double>(0.1);
    auto outputGrad = NDArrayFactory::create<double>('c', {2, 2});

    nd4j::ops::multiply forwardOp;
    nd4j::ops::multiply_bp backwardOp;
    const OpArgsHolder forwardArgs({&lhs, &rhs}, {}, {});
    const OpArgsHolder backwardArgs({&lhs, &rhs, &outputGrad}, {}, {});

    const bool gradientsMatch = GradCheck::checkGrad(forwardOp, backwardOp, forwardArgs, backwardArgs);
    ASSERT_TRUE(gradientsMatch);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_bp_test3) {
    // Gradient check of multiply against multiply_bp with the operands
    // swapped relative to multiply_bp_test2: scalar first, {2,2} array second.
    auto matrix = NDArrayFactory::create<double>('c', {2, 2}, {1., 2., 3., 4.});
    auto scalar = NDArrayFactory::create<double>(0.1);
    auto outputGrad = NDArrayFactory::create<double>('c', {2, 2});

    nd4j::ops::multiply forwardOp;
    nd4j::ops::multiply_bp backwardOp;
    const OpArgsHolder forwardArgs({&scalar, &matrix}, {}, {});
    const OpArgsHolder backwardArgs({&scalar, &matrix, &outputGrad}, {}, {});

    const bool gradientsMatch = GradCheck::checkGrad(forwardOp, backwardOp, forwardArgs, backwardArgs);
    ASSERT_TRUE(gradientsMatch);
}
////////////////////////////////////////////////////////////////////////////////
// multiply_bp gradient check: x and y are both 2x2, so no broadcasting is involved.
TEST_F(DeclarableOpsTests9, multiply_bp_test4) {

    auto x    = NDArrayFactory::create<double>('c', {2, 2}, {1., 2., 3., 4.});
    auto y    = NDArrayFactory::create<double>('c', {2, 2}, {0.1, 0.2, 0.3, 0.4});
    // 2x2 array passed as the third backprop input; its values are left as created
    auto dLdz = NDArrayFactory::create<double>('c', {2, 2});

    const OpArgsHolder argsHolderFF({&x, &y}, {}, {});
    const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {});

    nd4j::ops::multiply opFF;
    nd4j::ops::multiply_bp opBP;

    // the test passes only when GradCheck::checkGrad accepts opBP for opFF
    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

    ASSERT_TRUE(isGradCorrect);
}
////////////////////////////////////////////////////////////////////////////////
// multiply_bp gradient check: 2x2 array x multiplied by the length-2 vector y.
TEST_F(DeclarableOpsTests9, multiply_bp_test5) {

    auto x    = NDArrayFactory::create<double>('c', {2, 2}, {1., 2., 3., 4.});
    auto y    = NDArrayFactory::create<double>('c', {2}, {0.1, 0.2});
    // 2x2 array passed as the third backprop input; its values are left as created
    auto dLdz = NDArrayFactory::create<double>('c', {2, 2});

    const OpArgsHolder argsHolderFF({&x, &y}, {}, {});
    const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {});

    nd4j::ops::multiply opFF;
    nd4j::ops::multiply_bp opBP;

    // the test passes only when GradCheck::checkGrad accepts opBP for opFF
    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

    ASSERT_TRUE(isGradCorrect);
}
////////////////////////////////////////////////////////////////////////////////
// multiply_bp gradient check: length-2 vector x multiplied by the 2x2 array y
// (operand roles swapped relative to multiply_bp_test5).
TEST_F(DeclarableOpsTests9, multiply_bp_test6) {

    auto y    = NDArrayFactory::create<double>('c', {2, 2}, {1., 2., 3., 4.});
    auto x    = NDArrayFactory::create<double>('c', {2}, {0.1, 0.2});
    // 2x2 array passed as the third backprop input; its values are left as created
    auto dLdz = NDArrayFactory::create<double>('c', {2, 2});

    const OpArgsHolder argsHolderFF({&x, &y}, {}, {});
    const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {});

    nd4j::ops::multiply opFF;
    nd4j::ops::multiply_bp opBP;

    // the test passes only when GradCheck::checkGrad accepts opBP for opFF
    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

    ASSERT_TRUE(isGradCorrect);
}
////////////////////////////////////////////////////////////////////////////////
// multiply_bp gradient check: 2x1 column x multiplied by the 2x3 array y.
TEST_F(DeclarableOpsTests9, multiply_bp_test7) {

    auto y    = NDArrayFactory::create<double>('c', {2, 3}, {1., 2., 3., 4., 5., 6.});
    auto x    = NDArrayFactory::create<double>('c', {2, 1}, {0.1, 0.2});
    // 2x3 array passed as the third backprop input; its values are left as created
    auto dLdz = NDArrayFactory::create<double>('c', {2, 3});

    const OpArgsHolder argsHolderFF({&x, &y}, {}, {});
    const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {});

    nd4j::ops::multiply opFF;
    nd4j::ops::multiply_bp opBP;

    // the test passes only when GradCheck::checkGrad accepts opBP for opFF
    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

    ASSERT_TRUE(isGradCorrect);
}
////////////////////////////////////////////////////////////////////////////////
// multiply_bp gradient check: x is {1,3,4} and y is {2,1,4}; dLdz is {2,3,4}.
TEST_F(DeclarableOpsTests9, multiply_bp_test8) {

    auto y    = NDArrayFactory::create<double>('c', {2, 1, 4});
    auto x    = NDArrayFactory::create<double>('c', {1, 3, 4});
    // {2,3,4} array passed as the third backprop input; its values are left as created
    auto dLdz = NDArrayFactory::create<double>('c', {2, 3, 4});

    x.linspace(1., 0.5);
    y.linspace(0.1, 0.05);

    const OpArgsHolder argsHolderFF({&x, &y}, {}, {});
    const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {});

    nd4j::ops::multiply opFF;
    nd4j::ops::multiply_bp opBP;

    // the test passes only when GradCheck::checkGrad accepts opBP for opFF
    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

    ASSERT_TRUE(isGradCorrect);
}
////////////////////////////////////////////////////////////////////////////////
// Runs floormod_bp on 10x10 inputs and requires both returned arrays to equal dLdz.
// A GradCheck-based comparison against nd4j::ops::floormod is not performed here.
TEST_F(DeclarableOpsTests9, Floormod_BP_Test_2) {

    auto y    = NDArrayFactory::create<double>('c', {10, 10});
    auto x    = NDArrayFactory::create<double>('c', {10, 10});
    auto dLdz = NDArrayFactory::create<double>('c', {10, 10});

    x.linspace(4);
    y.linspace(3);
    dLdz.linspace(1);

    nd4j::ops::floormod_bp backpropOp;
    auto gradients = backpropOp.execute({&x, &y, &dLdz}, {}, {});

    ASSERT_TRUE(gradients->status() == ND4J_STATUS_OK);

    // output 0 and output 1 are each compared against the upstream gradient
    ASSERT_TRUE(dLdz.equalsTo(gradients->at(0)));
    ASSERT_TRUE(dLdz.equalsTo(gradients->at(1)));

    delete gradients;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F ( DeclarableOpsTests9 , Dynamic_Partition_BP_1 ) {
auto x = NDArrayFactory : : create < double > ( ' c ' , { 2 , 3 , 4 } ) ;
Merge master to upstream (#7945)
* Shugeo strided slice zeros (#14)
* Modified strided_slice op to properly work with empty-like shapes.
* Fixed test for reduce_mean with empty-like input.
* [WIP] Last merge (#15)
* correct logsoftmax looss (#2)
* Small SameDiff listener fix (#4)
* Various fixes (#6)
* #7839 Fix for asXMatrix and tests
* #7866 EmbeddingSequenceLayer dtype fix + test
* #7856 SameDiff save/load stream methods
* #7859 RegressionEvaluation rank 4 fix + tests + axis configuration
* EvaluationBinary 3d/4d
* More evaluation 3d/4d tests
* #7847 Evaluation empty checks
* Small test ifx
* #7848 Fix median edge case
* Improve DL4J samediff layer tests
* [WIP] FastText wrapper implemented (#8)
* FastText implemented
* Some fixes
* Fix shapes for wordsNearest
* Validation of input vectors
* Fixes
* Fixed test
* Thread tagged
* Some tweaks
* setContextClassLoader for DeallocatorServiceThread
* Numpy format tests (#1)
* Various fixes (#11)
* #7852 SameDiff gather fix
* #7892 SameDiff placeholder to constant conversion
* #7890 validate input rank for MLN/CG init methods
* Fix broken permute shape calculation
* Permute and gather fixes
* Tests
* #7850 LogSumExp fix + test
* Handful of test fixes
* Empty arrays with non-scalar shapes (#10)
* minor rearrangements for lambdas
* empty tensors with non-scalar shapes
* numpy empty tensors with non-scalar shapes
* few more empty tweaks
* Small fixes
* conv3d signature update
* micro fix in batchnorm mkldnn
* Import fixes
* Fix
* MKL-DNN update
* Small fill fix
* fill with empty input + test
* Fixes
* Small error improvement
* Fix
* one special test
* couple of fixes for lstm
* Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone
* Fixes
* FP16
* Unsigned
* BFloat16
* Fill op - empty tweaks
* - couple of fixes for empty arrays construction
- stack updated
* strided slice fix
* one transform test
* provide method for reducing shapeInfo in case of input array is empty
* Fixed reduceAlongDimensions to use empty input properly.
* couple of broadcast tests
* couple of tests broadcast tests + tweak to make them pass
* add check of non-empty to methods producing sub-arrays
* Fixed reshapeC with zeros in shape.
* complete empty check in reduce_... legacy ops
* Concat and cumsum/prod
* Tweak to empty shape inference on import
* add empty check to the rest of reduce legacy ops
* one more test
* correct typo in evalReduceShapeInfoEmpty
* Added tests for reduce_* ops to tests with zero shapes.
* few more tests for empty reductions
* Fixed strided_slice op with empty case and tests.
* one more empty reduction test
* Fixed strided_slice test.
* add empty check to NDArray::reshapei
* infOrMax
* empty min/max with infinity tests
* made unstack working correctly with empty arrays
* few IndexReduce tests + tweaks for empty shapes
* add test for empty concat
* few tests fixed
* Validation fix for reductions on empty shapes
* Reverse fix
* Reduction shape calc fixes
* SameDiff.generateOutputVariable: don't use shape function to determine number of outputs
* Range fix
* - NDArray constructor updated for scalars/empty arrays
- few tests fixed
* More fixes
* Empty creator fixes
* concat fix
* concat fix
* TF import tests: allow 'both all NaN' and 'both all inf' to pass
* Slice, zero fraction, and reshape fixes
* transpose, gather
* Zero fraction
* scalar cast fix
* Empty reduction axis support
* few more tests fixed
* Fixed input checks conforming with TF for concat op and tests.
* few tests fixed
* matmul scalar shape fix
* Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats.
* broadcast bool fix
* few more tests
* few more tests
* correct evalReduceShapeInfoEmpty
* argmax/argmin + tests
* one more empty edge case + one more test
* argmax/argmin/realdiv_bp tweaks
* empty reshape test + fix
* Helper fixes
* Small fixes
* Gather test fix
* Gather test fix
* Small fixes
* reduce scalar zero values
* scalar mean workaround
* Remove debug code
* along dim mean workaround
* one more test
* - equalsTo() tweak for empty arrays
- one more test
* broadcast tweaks
* [WIP] Fixing outstanding issues for NLP (#9)
* Avoid using not-inited objects
* Test fixed.
* Redundant method avoided for models like FastText
* KMeans++ implementation
* KMeans++ implementation
* Disable parallel execution
* KMeans++
* Tests
* Dev branch merge (#16)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Fix some issues on master (#17)
* Fix DataVec test issue
* Fix issue with dl4j SameDiff output layer
* Dtype fix for lambda layers
* #7912 BertIterator dtype fix (use float32 not global default)
* [WIP] Next set of CUDA stuff (#7)
New CUDA implementations and improvements
* bad file
* Dev branch master merge (#23)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* SameDiff ops, TF import and fixes (#24)
* CheckNumerics tests + fixes + misc fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fake quant
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* FakeQuantWithMinMaxArgs
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* CheckNumerics fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix libnd4j ALL_INTS and ALL_FLOATS declaration (uint and bfloat types)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Javadoc
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Exception tweak
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix for out of scope stack allocated var use
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignores
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignore for known failing test (already logged issue)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Merge upstream to fork (#25)
* Add thousand-separator commas to TotalParams (#7915)
* Add thousand-separator commas to TotalParams
The number of parameters can be quite large, and it would help the reading of the summary printout to have the TotalParams column & values at the bottom have thousand-separator-commas in them.
* Add thousand-separator commas to MultiLayerNetwork
Corresponding change to MultiLayerNetwork
Signed-off-by: Jxtps Jxtps <jxtps435@gmail.com>
* Update contributing and issue/PR templates (#7934)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix link to AdaDelta paper (#7942)
Fix link to AdaDelta paper hosted on matthewzeiler.com
Signed-off-by: Jxtps
* Fixes, and ignores for known/logged failing issues (#7943)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* SameDiff + DL4J/SameDiff: Multiple fixes (#28)
* #7919 HDF5 attribute buffer length fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7909 Arbiter constructor exception ux improvements
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7925 RNN output layer length checks
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Add listener for validating inputs are not incorrectly modified
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Integrate NonInplaceValidationListener into tests
* #7844 DL4J SameDiff fixes for variable minibatch size
* DL4J SameDiff fixes - ensure gradient for input placeholder is available
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Tweaks to ExternalErrorsFunction - use placeholders, make more robust
* Another fix
* More fixes
* More SameDiff/DL4J fixes
* Scope out scalar array creation in BaseScalarOp
* Remove debug code
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] Final dev branch merge (#29)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* [WIP] Multiple dataset iterators (#27)
* Splitting dataset into arbitrary number
* Fixes
* Multiple split of iterator
* Test
* Test
* Some fixes
* signature change
* one more tweak
Signed-off-by: raver119 <raver119@gmail.com>
* one more test for sequential use of DataSetIteratorSplitter
Signed-off-by: raver119 <raver119@gmail.com>
* Fixes
* Fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* minor test fix
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* couple of assertions tweaked
Signed-off-by: raver119 <raver119@gmail.com>
* MDS splitter test :/
Signed-off-by: raver119 <raver119@gmail.com>
* Minor refactoring
* Multi dataset
* Some fixes
* More tests
* Small number of test fixes/improvements (failures on CI) (#31)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] More CUDA stuff (#26)
* initial commit
Signed-off-by: raver119 <raver119@gmail.com>
* LRN BP CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* less memory
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed bug with crop_and_resize op helper.
* get rid of unnecessary index-calculation dunction
Signed-off-by: Yurii <yurii@skymind.io>
* Fixed sort with nth_element cuda-based helper.
* Refactored nth_element.
* Refactored nth_element op and tests.
* Modified usage of dim array with sortTad routine.
* Refactored main routine of helper for non_max_image_suppression op.
* non_max_image_suppression op helper with cuda kernel implementation. Initial revision.
* fix vol2col cuda kernel
* meh
Signed-off-by: raver119 <raver119@gmail.com>
* topK concept
Signed-off-by: raver119 <raver119@gmail.com>
* unsorted topK with scanWitdh of 1
Signed-off-by: raver119 <raver119@gmail.com>
* correct vol2col tests
* sorted/unsorted topK
Signed-off-by: raver119 <raver119@gmail.com>
* implementation and fixing col2im/col2vol
* Corrected usage flags with input/output with reverse op.
* dup is const now
Signed-off-by: raver119 <raver119@gmail.com>
* percentile op
Signed-off-by: raver119 <raver119@gmail.com>
* group tests for mapool2d
Signed-off-by: Yurii <yurii@skymind.io>
* special test for george
Signed-off-by: raver119 <raver119@gmail.com>
* less threads for sortTad
Signed-off-by: raver119 <raver119@gmail.com>
* provide conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* remove auther in sort tad kernel code
Signed-off-by: Yurii <yurii@skymind.io>
* provide depthwise_conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* - max_pooling_with_argmax
- null check for special use
Signed-off-by: raver119 <raver119@gmail.com>
* dts cuda
Signed-off-by: raver119 <raver119@gmail.com>
* provide sconv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* std cuda
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op to conform TF implementation.
* Improved suppression helper.
* provide pooling3d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* more of minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* (bi)dynamic_rnn
Signed-off-by: raver119 <raver119@gmail.com>
* templates init order
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op.
* Added cuda kernel for non_max_suppression.
* CPU sort by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value tests
Signed-off-by: raver119 <raver119@gmail.com>
* Eliminate compiler error with cuda implementation.
* - repaired gradCheck in cuda
- provide conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* missed signature
Signed-off-by: raver119 <raver119@gmail.com>
* provide depthwise_conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* Implementation of lup helper with cuda kernel. Initial commit.
* further work on backprops for convolutions
Signed-off-by: Yurii <yurii@skymind.io>
* CUDA linear sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* CUDA tad sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* start providing of backprop for pooling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* Added atomicAdd for bool datatype.
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition scalar CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* important comment
Signed-off-by: raver119 <raver119@gmail.com>
* fix pooling2d/3d backprop helpers
Signed-off-by: Yurii <yurii@skymind.io>
* Added non-linear test with dynamic_partition.
* Improved test for dynamic_partition.
* dynamic_partition TAD concept
Signed-off-by: raver119 <raver119@gmail.com>
* - dynamic_partition TAD CUDA impl
- dynamic_partition TAD CPU fix
Signed-off-by: raver119 <raver119@gmail.com>
* - rewrite cpu code for usampling2d/3d
- write cuda code for usampling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* dynamic_stitch CUDA vector case
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case impl
Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for dynamic_stitch 3D-4D cases.
* minor tests tweaks
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed type check for dynamic stitch.
* min/max bp
Signed-off-by: raver119 <raver119@gmail.com>
* rewrite code for upsampling2d/3d cpu
Signed-off-by: Yurii <yurii@skymind.io>
* reduce min/max/norm_max bp
Signed-off-by: raver119 <raver119@gmail.com>
* lup implementation. Additional enhancements.
* provide code for upsamling2d/3d backprop
Signed-off-by: Yurii <yurii@skymind.io>
* weightedCrossEntropyWithLogits
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed template math atomicMul for 64bit ints.
* Refactored dynamic_partition_bp op.
* inverseBroadcast fix
Signed-off-by: raver119 <raver119@gmail.com>
* DynamicPartitionBP test datatype fixed.
* - nd4j_atomicMul Windows fix
- cpu/NDArrayLambda.hpp excluded from CUDA
Signed-off-by: raver119 <raver119@gmail.com>
2019-06-27 17:37:04 +02:00
auto y = NDArrayFactory : : create < int > ( ' c ' , { 2 , 3 } , { 0 , 1 , 2 , 1 , 0 , 2 } ) ;
2019-06-06 14:21:15 +02:00
auto dLdzX = NDArrayFactory : : create < double > ( ' c ' , { 2 , 4 } ) ;
auto dLdzY = NDArrayFactory : : create < double > ( ' c ' , { 2 , 4 } ) ;
auto dLdzZ = NDArrayFactory : : create < double > ( ' c ' , { 2 , 4 } ) ;
2019-07-12 10:51:51 +02:00
auto exp = NDArrayFactory : : create < double > ( ' c ' , { 2 , 3 , 4 } , { 1 , 1 , 1 , 1 , 2 , 2 , 2 , 2 , 3 , 3 , 3 , 3 , 2 , 2 , 2 , 2 , 1 , 1 , 1 , 1 , 3 , 3 , 3 , 3 } ) ;
2019-06-06 14:21:15 +02:00
x . linspace ( 1 ) ;
2019-07-12 10:51:51 +02:00
// dLdzX.linspace(1);
// dLdzY.linspace(2);
// dLdzZ.linspace(3);
dLdzX . assign ( 1 ) ;
dLdzY . assign ( 2 ) ;
dLdzZ . assign ( 3 ) ;
2019-06-06 14:21:15 +02:00
nd4j : : ops : : dynamic_partition op1 ;
auto res1 = op1 . execute ( { & x , & y } , { } , { 3 } ) ;
nd4j : : ops : : dynamic_partition_bp op2 ;
2019-07-12 10:51:51 +02:00
auto res2 = op2 . execute ( { & x , & y , & dLdzX , & dLdzY , & dLdzZ } , { } , { 3 } ) ;
2019-06-06 14:21:15 +02:00
ASSERT_TRUE ( res2 - > status ( ) = = ND4J_STATUS_OK ) ;
ASSERT_TRUE ( res2 - > size ( ) = = 2 ) ;
2019-07-12 10:51:51 +02:00
// printf("How many: %ul\n", res2->size());
// res2->at(0)->printBuffer("Ouputput0");
// res2->at(1)->printBuffer("Ouputput1");
ASSERT_TRUE ( res2 - > at ( 0 ) - > equalsTo ( exp ) ) ;
2019-06-06 14:21:15 +02:00
delete res1 ;
delete res2 ;
}
//////////////////////////////////////////////////////////////////////
//TEST_F(DeclarableOpsTests9, Dynamic_Partition_BP_2) {
//
// auto x = NDArrayFactory::create<double>('c', {2, 3, 4});
// auto y = NDArrayFactory::create<int>('c', {2, 3}, {0, 1, 2, 1, 0, 2});
// auto dLdzX = NDArrayFactory::create<double>('c', {2, 4});
// auto dLdzY = NDArrayFactory::create<double>('c', {2, 4});
// auto dLdzZ = NDArrayFactory::create<double>('c', {2, 4});
// x.linspace(1);
// dLdzX.linspace(1);
// dLdzY.linspace(1);
// dLdzZ.linspace(1);
//
// const OpArgsHolder argsHolderFF({&x, &y}, {}, {3});
// const OpArgsHolder argsHolderBP({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, {3});
//
// nd4j::ops::dynamic_partition opFF;
// nd4j::ops::dynamic_partition_bp opBP;
//
// const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
//
// ASSERT_TRUE(isGradCorrect);
//}
////////////////////////////////////////////////////////////////////////////////
// Runs floormod_bp with x of shape {2,1,3} and y of shape {1,3}, upstream gradient of ones,
// and checks the second returned array (shape and values) against exp.
TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) {

    auto x   = NDArrayFactory::create<double>('c', {2, 1, 3}, {2.0, 6.0, -3.0, 2.0, 6.0, -3.0});
    auto y   = NDArrayFactory::create<double>('c', {1, 3}, {-3.0, 2.0, -2.0});
    auto exp = NDArrayFactory::create<double>('c', {1, 3}, {-1., 0., -1.});
    auto eps = NDArrayFactory::create<double>('c', {2, 1, 3});
    eps.assign(1.f);

    nd4j::ops::floormod_bp op;
    auto result = op.execute({&x, &y, &eps}, {}, {});

    // verify the op succeeded before touching its outputs, as the sibling tests do
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    ASSERT_TRUE(result->size() == 2);

    // only the second output is verified; the first output is not inspected
    auto gradY = result->at(1);

    ASSERT_TRUE(exp.isSameShape(gradY));
    ASSERT_TRUE(exp.equalsTo(gradY));

    delete result;
}
////////////////////////////////////////////////////////////////////
// batchnorm_bp gradient check with input, mean, variance, gamma and beta supplied.
// NOTE(review): the integer args {1, 1} presumably flag that gamma and beta are present,
// and 1e-5 is presumably epsilon - confirm against the batchnorm op.
TEST_F(DeclarableOpsTests9, batchnorm_bp_test1) {

    auto input    = NDArrayFactory::create<double>('c', {3, 2});
    auto mean     = NDArrayFactory::create<double>('c', {2, 3, 2});
    auto variance = NDArrayFactory::create<double>('c', {2, 3, 1, 3, 2});
    auto gamma    = NDArrayFactory::create<double>('c', {1, 1});
    auto beta     = NDArrayFactory::create<double>('c', {1, 2});
    auto dLdO     = NDArrayFactory::create<double>('c', {2, 3, 2, 3, 2});

    input.linspace(0.1, 0.1);
    mean.assign(1.);
    variance.assign(0.5);
    gamma.assign(1.2);
    beta.assign(1.);

    nd4j::ops::batchnorm forwardOp;
    nd4j::ops::batchnorm_bp backwardOp;

    // the backprop holder takes the forward inputs plus dLdO as the last array
    const OpArgsHolder forwardArgs({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1});
    const OpArgsHolder backwardArgs({&input, &mean, &variance, &gamma, &beta, &dLdO}, {1e-5}, {1, 1});

    ASSERT_TRUE(GradCheck::checkGrad(forwardOp, backwardOp, forwardArgs, backwardArgs));
}
////////////////////////////////////////////////////////////////////
// batchnorm_bp gradient check with input, mean, variance and gamma supplied (no beta).
// NOTE(review): the integer args {1, 0} presumably flag gamma present / beta absent - confirm.
TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) {

    auto input    = NDArrayFactory::create<double>('c', {2, 3, 2, 3, 2});
    auto mean     = NDArrayFactory::create<double>('c', {2, 3, 2});
    auto variance = NDArrayFactory::create<double>('c', {2, 3, 1, 3, 1});
    auto gamma    = NDArrayFactory::create<double>('c', {1, 1});
    auto dLdO     = NDArrayFactory::create<double>('c', {2, 3, 2, 3, 2});

    input.linspace(0.1, 0.1);
    mean.assign(1.);
    variance.assign(0.5);
    gamma.assign(1.2);

    nd4j::ops::batchnorm forwardOp;
    nd4j::ops::batchnorm_bp backwardOp;

    // the backprop holder takes the forward inputs plus dLdO as the last array
    const OpArgsHolder forwardArgs({&input, &mean, &variance, &gamma}, {1e-5}, {1, 0});
    const OpArgsHolder backwardArgs({&input, &mean, &variance, &gamma, &dLdO}, {1e-5}, {1, 0});

    ASSERT_TRUE(GradCheck::checkGrad(forwardOp, backwardOp, forwardArgs, backwardArgs));
}
////////////////////////////////////////////////////////////////////
// batchnorm_bp gradient check with only input, mean and variance supplied.
// NOTE(review): the integer args {0, 0} presumably flag gamma and beta as absent - confirm.
TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) {

    auto input    = NDArrayFactory::create<double>('c', {2, 3, 1, 3});
    auto mean     = NDArrayFactory::create<double>('c', {1, 3, 2, 1});
    auto variance = NDArrayFactory::create<double>('c', {2, 1, 2, 3});
    auto dLdO     = NDArrayFactory::create<double>('c', {2, 3, 2, 3});

    input.linspace(0.1, 0.1);
    mean.assign(1.);
    variance.assign(0.5);

    nd4j::ops::batchnorm forwardOp;
    nd4j::ops::batchnorm_bp backwardOp;

    // the backprop holder takes the forward inputs plus dLdO as the last array
    const OpArgsHolder forwardArgs({&input, &mean, &variance}, {1e-5}, {0, 0});
    const OpArgsHolder backwardArgs({&input, &mean, &variance, &dLdO}, {1e-5}, {0, 0});

    ASSERT_TRUE(GradCheck::checkGrad(forwardOp, backwardOp, forwardArgs, backwardArgs));
}
/*
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) {

    const int bS = 2;
    const int iS = 3;
    const int nU = 4;

    NDArray x   ('c', {bS, iS}, nd4j::DataType::DOUBLE);
    NDArray hi  ('c', {bS, nU}, nd4j::DataType::DOUBLE);
    NDArray W   ('c', {iS+nU, 2*nU}, nd4j::DataType::DOUBLE);
    NDArray Wc  ('c', {iS+nU, nU}, nd4j::DataType::DOUBLE);
    NDArray b   ('c', {2*nU}, nd4j::DataType::DOUBLE);
    NDArray bc  ('c', {nU}, nd4j::DataType::DOUBLE);
    NDArray dLdr('c', {bS, nU}, nd4j::DataType::DOUBLE);
    NDArray dLdu('c', {bS, nU}, nd4j::DataType::DOUBLE);
    NDArray dLdc('c', {bS, nU}, nd4j::DataType::DOUBLE);
    NDArray dLdh('c', {bS, nU}, nd4j::DataType::DOUBLE);

    x.linspace(-5, 0.5);
    hi = 1.;
    W = 0.003;
    Wc = 0.006;
    b = 0.5;
    bc = 0.35;

    const OpArgsHolder argsHolderFF({&x, &hi, &W, &Wc, &b, &bc}, {}, {});
    nd4j::ops::gruCell op;
    auto results = op.execute(argsHolderFF);

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto u = results->at(1);    // [bS, nU]
    auto c = results->at(2);    // [bS, nU]
    auto h = results->at(3);    // [bS, nU]

    dLdh = 1.; // SUM loss

    NDArray Wch = Wc({iS,iS+nU, 0,0}); // [nU, nU]
    NDArray dhdc  = 1. - *u;
    NDArray dhdu  = hi - *c;
    NDArray dcdZc = 1. - *c * *c;
    dLdc.assign(dLdh * dhdc);
    dLdu.assign(dLdh * dhdu);
    dLdr.assign(mmul(dLdc * dcdZc * hi, Wch.transpose()));

    delete results;

    const OpArgsHolder argsHolderBP({&x, &hi, &W, &Wc, &b, &bc, &dLdr, &dLdu, &dLdc, &dLdh}, {}, {});

    nd4j::ops::gruCell opFF;
    nd4j::ops::gruCell_bp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1, 1, 1, 1, 1}, {0., 1.}, nd4j::GradCheck::LossFunc::SUM, true);

    ASSERT_TRUE(isGradCorrect);
}
*/

////////////////////////////////////////////////////////////////////
// cholesky on a single 3x3 double matrix. exp is lower-triangular and exp * exp^T
// reproduces x (e.g. row 3: 64 + 25 + 9 = 98).
TEST_F(DeclarableOpsTests9, Cholesky_Test_1) {

    NDArray x   = NDArrayFactory::create<double>('c', {3, 3}, {4, 12, -16, 12, 37, -43, -16, -43, 98});
    NDArray exp = NDArrayFactory::create<double>('c', {3, 3}, {2., 0., 0., 6., 1., 0., -8., 5., 3.});

    nd4j::ops::cholesky op;

    auto result = op.execute({&x}, {}, {});
    ASSERT_EQ(result->status(), ND4J_STATUS_OK);

    auto res = result->at(0);
    ASSERT_TRUE(exp.equalsTo(res));

    delete result;
}
////////////////////////////////////////////////////////////////////
// Cholesky op on a {2, 3, 3} double input, i.e. two 3x3 matrices in one array.
// Each 3x3 slice of `exp` is the lower-triangular factor L of the matching
// slice of `x` (L * L^T reproduces the input slice).
TEST_F(DeclarableOpsTests9, Cholesky_Test_2) {
    NDArray x = NDArrayFactory::create<double>('c', {2, 3, 3}, {4, 12, -16, 12, 37, -43, -16, -43, 98, 1, 1, 1, 1, 2, 2, 1, 2., 6});
    NDArray exp = NDArrayFactory::create<double>('c', {2, 3, 3}, {2., 0., 0., 6., 1., 0., -8., 5., 3., 1., 0., 0., 1., 1., 0, 1., 1., 2.});

    nd4j::ops::cholesky op;
    auto result = op.execute({&x}, {}, {});

    // Capture everything we need from `result` first: ASSERT_* returns from the
    // test body on failure, so asserting before `delete` would leak the result set.
    const auto status = result->status();
    bool matches = false;
    if (status == ND4J_STATUS_OK) {
        auto res = result->at(0);
        // res->printIndexedBuffer("Output for Cholesky 2");
        matches = exp.equalsTo(res);
    }
    delete result;

    ASSERT_EQ(status, ND4J_STATUS_OK);
    ASSERT_TRUE(matches);
}
////////////////////////////////////////////////////////////////////
// Same data as Cholesky_Test_2, but with float arrays instead of double:
// a {2, 3, 3} input whose 3x3 slices are compared against their
// lower-triangular factors L (L * L^T reproduces the input slice).
TEST_F(DeclarableOpsTests9, Cholesky_Test_3) {
    NDArray x = NDArrayFactory::create<float>('c', {2, 3, 3}, {4, 12, -16, 12, 37, -43, -16, -43, 98, 1, 1, 1, 1, 2, 2, 1, 2., 6});
    NDArray exp = NDArrayFactory::create<float>('c', {2, 3, 3}, {2., 0., 0., 6., 1., 0., -8., 5., 3., 1., 0., 0., 1., 1., 0, 1., 1., 2.});

    nd4j::ops::cholesky op;
    auto result = op.execute({&x}, {}, {});

    // Capture everything we need from `result` first: ASSERT_* returns from the
    // test body on failure, so asserting before `delete` would leak the result set.
    const auto status = result->status();
    bool matches = false;
    if (status == ND4J_STATUS_OK) {
        auto res = result->at(0);
        // res->printIndexedBuffer("Output for Cholesky 3");
        matches = exp.equalsTo(res);
    }
    delete result;

    ASSERT_EQ(status, ND4J_STATUS_OK);
    ASSERT_TRUE(matches);
}
////////////////////////////////////////////////////////////////////
// TEST_F(DeclarableOpsTests9, gru_bp_test1) {
// const int time = 5;
// const int bS = 2;
// const int iS = 3;
// const int nU = 4;
// NDArray<double> x ('c', {time, bS, iS});
// NDArray<double> h0 ('c', {bS, nU});
// NDArray<double> Wx ('c', {iS, 3*nU});
// NDArray<double> Wh ('c', {nU, 3*nU});
// NDArray<double> b ('c', {3*nU});
// NDArray<double> dLdh ('c', {time, bS, nU});
// x.linspace(0.5, 0.5);
// h0 = 1.;
// Wx = 0.003;
// Wh = 0.006;
// b = 0.5;
// const OpArgsHolder<double> argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {});
// const OpArgsHolder<double> argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {});
// nd4j::ops::gru<double> opFF;
// nd4j::ops::gru_bp<double> opBP;
// const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
// ASSERT_TRUE(isGradCorrect);
// }
//