- new NDArray methods like()/ulike() (#77)

- fix for depthwise_conv2d_bp + special test

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-07-23 08:38:00 +03:00 committed by AlexDBlack
parent 00e6296140
commit 4f2dae23a1
4 changed files with 44 additions and 1 deletion

View File

@@ -212,6 +212,20 @@ namespace nd4j {
*/ */
NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false); NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false);
/**
* This method returns new array with the same shape & data type
* @return
*/
NDArray like();
/**
* This method returns new uninitialized array with the same shape & data type
* @return
*/
NDArray ulike();
/** /**
* this constructor creates new NDArray with shape matching "other" array, * this constructor creates new NDArray with shape matching "other" array,
* doesn't copy "other" elements into new array !!! * doesn't copy "other" elements into new array !!!

View File

@@ -4129,6 +4129,19 @@ Nd4jLong NDArray::getOffset(const Nd4jLong i) const {
return shape::getIndexOffset(i, _shapeInfo, lengthOf()); return shape::getIndexOffset(i, _shapeInfo, lengthOf());
} }
NDArray NDArray::like() {
NDArray res(this->shapeInfo(), this->dataType(), false, this->getContext());
return res;
}
NDArray NDArray::ulike() {
// FIXME: it should be non-memset array
NDArray res(this->shapeInfo(), this->dataType(), false, this->getContext());
return res;
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
NDArray* NDArray::diagonal(const char type) const { NDArray* NDArray::diagonal(const char type) const {

View File

@@ -990,8 +990,9 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
if(gradB->rankOf() == 2) if(gradB->rankOf() == 2)
gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradO->reduceAlongDimension(reduce::Sum, gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW gradO->reduceAlongDimension(reduce::Sum, gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW
if(gradBR != gradB) if(gradBR != gradB)
delete gradB; delete gradBR;
} }
//----- calculation of gradI -----// //----- calculation of gradI -----//

View File

@@ -156,6 +156,21 @@ TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) {
delete result; delete result;
} }
TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) {
auto in = NDArrayFactory::create<float>('c', {4, 8, 64, 64});
auto w = NDArrayFactory::create<float>('c', {2, 2, 8, 2});
auto b = NDArrayFactory::create<float>('c', {1, 16});
auto grad = NDArrayFactory::create<float>('c', {4, 16, 64, 64});
auto gradI = in.like();
auto gradW = w.like();
auto gradB = b.like();
nd4j:ops::depthwise_conv2d_bp op;
auto status = op.execute({&in, &w, &b, &grad}, {&gradI, &gradW, &gradB}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0}, {});
ASSERT_EQ(Status::OK(), status);
}
TEST_F(DeclarableOpsTests15, test_matmul_bp_1) { TEST_F(DeclarableOpsTests15, test_matmul_bp_1) {
auto a = NDArrayFactory::create<double>('c', {1, 3}); auto a = NDArrayFactory::create<double>('c', {1, 3});
auto b = NDArrayFactory::create<double>('c', {1, 4}); auto b = NDArrayFactory::create<double>('c', {1, 4});