- new NDArray methods like()/ulike() (#77)
- fix for depthwise_conv2d_bp + special test Signed-off-by: raver119 <raver119@gmail.com>master
parent
00e6296140
commit
4f2dae23a1
|
@ -212,6 +212,20 @@ namespace nd4j {
|
|||
*/
|
||||
NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false);
|
||||
|
||||
|
||||
/**
|
||||
* This method returns new array with the same shape & data type
|
||||
* @return
|
||||
*/
|
||||
NDArray like();
|
||||
|
||||
/**
|
||||
* This method returns new uninitialized array with the same shape & data type
|
||||
* @return
|
||||
*/
|
||||
NDArray ulike();
|
||||
|
||||
|
||||
/**
|
||||
* this constructor creates new NDArray with shape matching "other" array,
|
||||
* doesn't copy "other" elements into new array !!!
|
||||
|
|
|
@ -4129,6 +4129,19 @@ Nd4jLong NDArray::getOffset(const Nd4jLong i) const {
|
|||
return shape::getIndexOffset(i, _shapeInfo, lengthOf());
|
||||
}
|
||||
|
||||
NDArray NDArray::like() {
|
||||
NDArray res(this->shapeInfo(), this->dataType(), false, this->getContext());
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
NDArray NDArray::ulike() {
|
||||
// FIXME: it should be non-memset array
|
||||
NDArray res(this->shapeInfo(), this->dataType(), false, this->getContext());
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
NDArray* NDArray::diagonal(const char type) const {
|
||||
|
||||
|
|
|
@ -990,8 +990,9 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
|
|||
if(gradB->rankOf() == 2)
|
||||
gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
|
||||
gradO->reduceAlongDimension(reduce::Sum, gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW
|
||||
|
||||
if(gradBR != gradB)
|
||||
delete gradB;
|
||||
delete gradBR;
|
||||
}
|
||||
|
||||
//----- calculation of gradI -----//
|
||||
|
|
|
@ -156,6 +156,21 @@ TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) {
|
||||
auto in = NDArrayFactory::create<float>('c', {4, 8, 64, 64});
|
||||
auto w = NDArrayFactory::create<float>('c', {2, 2, 8, 2});
|
||||
auto b = NDArrayFactory::create<float>('c', {1, 16});
|
||||
auto grad = NDArrayFactory::create<float>('c', {4, 16, 64, 64});
|
||||
|
||||
auto gradI = in.like();
|
||||
auto gradW = w.like();
|
||||
auto gradB = b.like();
|
||||
|
||||
nd4j:ops::depthwise_conv2d_bp op;
|
||||
auto status = op.execute({&in, &w, &b, &grad}, {&gradI, &gradW, &gradB}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0}, {});
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, test_matmul_bp_1) {
|
||||
auto a = NDArrayFactory::create<double>('c', {1, 3});
|
||||
auto b = NDArrayFactory::create<double>('c', {1, 4});
|
||||
|
|
Loading…
Reference in New Issue