[WIP] more fixes (#159)

* Added test for MatrixInverse with double input. Fixed matrixDeterminantKernel.

* Fixed kernels to avoid waste templating.

* Fixed logDeterminant kernel.

* Refactored type check for lup

* - decrease blockDim value for zeta op

Signed-off-by: Yurii <yurii@skymind.io>

* Added print for compound matrix with CUDA.

* Refactored upper matrix inversion kernels.

* - provide move constructor and move assignment operator for OpArgsHolder class

Signed-off-by: Yurii <yurii@skymind.io>

* Refactored usage of launch context.

* - add test for mergemax

Signed-off-by: Yurii <yurii@skymind.io>

* get rid of AveragingArrayProxy

Signed-off-by: raver119 <raver119@gmail.com>

* Refactoring of LUP inversion.

* Added prints for inversion.

* - add OpArgsHolder copy constructor and assignment operator

Signed-off-by: Yurii <yurii@skymind.io>

* Added test for lower inversion

* - fix bug in upsampling2d/3d_bp op

Signed-off-by: Yurii <yurii@skymind.io>

* Added expensive printfs to kernel.

* Refactored expensive kernel prints.

* Refactored expensive printfs

* - remove nullify

Signed-off-by: Yurii <yurii@skymind.io>

* Eliminated waste prints with tests.

* upsampling2d_bp test

Signed-off-by: raver119 <raver119@gmail.com>

* test updated

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-23 19:20:50 +03:00 committed by GitHub
parent 99cdf6d42b
commit f03b0ee78f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 1241 additions and 1009 deletions

View File

@ -1,59 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef DEV_TESTS_AVERAGINGARRAYPROXY_H
#define DEV_TESTS_AVERAGINGARRAYPROXY_H
#include "NDArray.h"
#include <utility>
#include <map>
#include <vector>
#include <mutex>
namespace nd4j {
// Proxy around a 2D NDArray that hands out per-(row, key) private copies of rows
// and can later fold those copies back into the original by averaging them per row.
class ND4J_EXPORT AveragingArrayProxy {
protected:
// array being proxied; the destructor does not delete it
NDArray *_original;
// (row, key) -> private row copy created by writeable()
std::map<std::pair<int,int>, NDArray*> _writeables;
// row -> every private copy created for that row; consumed by collapseWrites()
std::map<int, std::vector<NDArray*>> _writeablesLinear;
// row index appended each time a new private copy is created (may contain duplicates)
std::vector<int> _rows;
// arrays deleted by the destructor
std::vector<NDArray*> _references;
// guards the containers above in readable()/writeable()/writeableExists()
std::mutex _lock;
public:
// stores the pointer only; no copy of the array is made
explicit AveragingArrayProxy(NDArray *original);
// deletes every pointer held in _references
~AveragingArrayProxy();
// returns the private copy for (row, key) when one exists
// NOTE(review): in the accompanying .cpp the "no copy exists" branch does not return a value — confirm before use
NDArray* readable(int row, int key);
// returns the private copy for (row, key), creating it from the original row on first use
NDArray* writeable(int row, int key);
// forwards to _original->isEmpty()
bool isEmpty();
// true when a private copy is registered for the given (row, key) pair
bool writeableExists(std::pair<int, int> &key);
// convenience overload building the pair from row and key
bool writeableExists(int row, int key);
// writes the per-row average of all private copies back into the original;
// returns false when no private copy was ever created, true otherwise
bool collapseWrites();
};
}
#endif //DEV_TESTS_AVERAGINGARRAYPROXY_H

View File

@ -26,27 +26,42 @@
#include <dll.h>
namespace nd4j {
class ND4J_EXPORT OpArgsHolder {
private:
private:
std::vector<NDArray*> _inArrs = std::vector<NDArray*>();
std::vector<double> _tArgs = std::vector<double>();
std::vector<Nd4jLong> _iArgs = std::vector<Nd4jLong>();
std::vector<bool> _bArgs = std::vector<bool>();
std::vector<double> _tArgs = std::vector<double>();
std::vector<Nd4jLong> _iArgs = std::vector<Nd4jLong>();
std::vector<bool> _bArgs = std::vector<bool>();
std::vector<bool> _isArrAlloc = std::vector<bool>();
int _numInArrs = _inArrs.size();
int _numTArgs = _tArgs.size();
int _numIArgs = _iArgs.size();
int _numBArgs = _bArgs.size();
std::vector<bool> _isArrAlloc = std::vector<bool>();
public:
OpArgsHolder() = delete;
// default constructor
OpArgsHolder();
OpArgsHolder(const std::vector<NDArray*>& inArrs, const std::vector<double>& tArgs = std::vector<double>(), const std::vector<Nd4jLong>& iArgs = std::vector<Nd4jLong>(), const std::vector<bool>& bArgs = std::vector<bool>())
: _inArrs(inArrs), _tArgs(tArgs), _iArgs(iArgs), _bArgs(bArgs) { }
// copy constructor
OpArgsHolder(const OpArgsHolder& other);
// constructor
OpArgsHolder(const std::vector<NDArray*>& inArrs, const std::vector<double>& tArgs = std::vector<double>(), const std::vector<Nd4jLong>& iArgs = std::vector<Nd4jLong>(), const std::vector<bool>& bArgs = std::vector<bool>());
// move constructor
OpArgsHolder(OpArgsHolder&& other) noexcept;
// assignment operator
OpArgsHolder& operator=(const OpArgsHolder& other);
// move assignment operator
OpArgsHolder& operator=(OpArgsHolder&& other) noexcept;
const std::vector<NDArray*>& getInArrs() const
{return _inArrs; }
@ -77,8 +92,8 @@ public:
OpArgsHolder createArgsHolderForBP(const std::vector<NDArray*>& inGradArrs, const bool isInPlace = false) const;
~OpArgsHolder() noexcept;
~OpArgsHolder() noexcept;
};

View File

@ -229,7 +229,6 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc);
}
else if(ABC && aType == DataType::HALF) {
printf("!!!!!!!!\n");
float16 alphaH(alpha), betaH(beta);
status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc);
}

View File

@ -1,139 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../AveragingArrayProxy.h"
namespace nd4j {
// Remembers the array to proxy; only the pointer is stored, nothing is copied.
AveragingArrayProxy::AveragingArrayProxy(NDArray *original) : _original(original) {
}
// Releases every array that was registered in _references.
AveragingArrayProxy::~AveragingArrayProxy() {
    for (auto it = _references.begin(); it != _references.end(); ++it)
        delete *it;
}
// Reports whether a private copy is registered under `key` (a (row, key) pair).
// The lookup is done under _lock; the guard releases the mutex on every exit path.
bool AveragingArrayProxy::writeableExists(std::pair<int, int> &key) {
    std::lock_guard<std::mutex> guard(_lock);
    return _writeables.count(key) > 0;
}
// Convenience overload: packs (row, key) into a pair and defers to the pair-based lookup.
bool AveragingArrayProxy::writeableExists(int row, int key) {
    auto lookup = std::make_pair(row, key);
    return writeableExists(lookup);
}
// Returns an array to read row `row` from.
//
// If a private copy was already created for (row, key) it is returned. Otherwise a
// heap-allocated copy of the original row is created, registered in _references
// (so the destructor releases it) and returned.
//
// The previous implementation registered the address of a stack-local NDArray and
// then fell off the end of the function without returning; both are undefined behaviour.
//
// NOTE(review): the fallback returns a snapshot (dup) rather than a live view of the
// original row, and every such call adds one entry to _references — confirm this matches
// the intended semantics.
NDArray* AveragingArrayProxy::readable(int row, int key) {
    const std::pair<int, int> k(row, key);

    {
        // single lookup under the lock instead of a separate exists-check followed by operator[]
        std::lock_guard<std::mutex> guard(_lock);
        auto found = _writeables.find(k);
        if (found != _writeables.end())
            return found->second;
    }

    // same row-extraction + dup('c') sequence that writeable() uses to obtain an owned NDArray*
    auto rowView = (*_original)({row,row+1, 0,0});
    auto readable = rowView.dup('c');

    std::lock_guard<std::mutex> guard(_lock);
    _references.emplace_back(readable);
    return readable;
}
// The proxy is considered empty exactly when the wrapped array reports itself empty.
bool AveragingArrayProxy::isEmpty() {
    const bool originalIsEmpty = _original->isEmpty();
    return originalIsEmpty;
}
// Returns the private, writable copy of row `row` associated with `key`,
// creating it from the original row on first request.
//
// Newly created copies are registered in _writeables, _writeablesLinear, _rows and
// _references (the latter makes the destructor release them).
//
// The existence check and the registration are each done under _lock, and the
// registration re-checks the map: if another thread registered the same (row, key)
// in between, the freshly made copy is discarded and the registered one is returned.
// Previously two racing callers could both register a copy for the same key.
NDArray* AveragingArrayProxy::writeable(int row, int key) {
    const std::pair<int, int> k(row, key);

    {
        // fast path: the copy already exists
        std::lock_guard<std::mutex> guard(_lock);
        auto found = _writeables.find(k);
        if (found != _writeables.end())
            return found->second;
    }

    // build the copy outside the lock; we don't want views here, hence dup
    auto orig = (*_original)({row,row+1, 0,0});
    auto writeable = orig.dup('c');

    std::lock_guard<std::mutex> guard(_lock);
    auto inserted = _writeables.emplace(k, writeable);
    if (!inserted.second) {
        // lost the race: keep the copy that is already registered
        delete writeable;
        return inserted.first->second;
    }

    _references.emplace_back(writeable);

    // storing linear reference, for future averaging step
    // (operator[] creates the per-row vector on first use)
    _writeablesLinear[row].emplace_back(writeable);
    _rows.emplace_back(row);

    return writeable;
}
bool AveragingArrayProxy::collapseWrites() {
if (_writeables.empty())
return false;
for (int r = 0; r < _rows.size(); r++) {
auto row = _rows[r];
auto list = _writeablesLinear.at(row);
auto originalRow = (*_original)({row,row+1, 0,0});
if (list.size() == 1) {
originalRow.assign(list.at(0));
} else {
originalRow.assign(0.0);
for (int e = 0; e < list.size(); e++)
originalRow += *(list.at(e));
originalRow /= (int) list.size();
}
}
return true;
}
}

View File

@ -23,27 +23,122 @@
namespace nd4j {
////////////////////////////////////////////////////////////////////////
// default constructor: holder with no input arrays and no t/i/b arguments
OpArgsHolder::OpArgsHolder() {
    _inArrs.clear();
    _tArgs.clear();
    _iArgs.clear();
    _bArgs.clear();
    _isArrAlloc.clear();

    _numInArrs = _numTArgs = _numIArgs = _numBArgs = 0;
}
////////////////////////////////////////////////////////////////////////
// copy constructor: intentionally unusable, always throws std::runtime_error
OpArgsHolder::OpArgsHolder(const OpArgsHolder& other) {
    const char* message = "OpArgsHolder::OpArgsHolder copy constructor: don't use me !";
    throw std::runtime_error(message);
}
////////////////////////////////////////////////////////////////////////
// constructor: copies the given argument lists and caches their sizes;
// no input array is marked as owned by the holder
OpArgsHolder::OpArgsHolder(const std::vector<NDArray*>& inArrs,
                           const std::vector<double>& tArgs,
                           const std::vector<Nd4jLong>& iArgs,
                           const std::vector<bool>& bArgs) {
    _inArrs = inArrs;
    _numInArrs = _inArrs.size();

    _tArgs = tArgs;
    _numTArgs = _tArgs.size();

    _iArgs = iArgs;
    _numIArgs = _iArgs.size();

    _bArgs = bArgs;
    _numBArgs = _bArgs.size();

    _isArrAlloc.clear();
}
////////////////////////////////////////////////////////////////////////
// move constructor: takes over all argument lists and the ownership flags of `other`.
// `other` is left as an explicitly empty holder (empty lists, zero counters), so its
// destructor deletes nothing and its cached counters agree with its (empty) vectors.
// Previously the counters of the moved-from object kept their old values.
OpArgsHolder::OpArgsHolder(OpArgsHolder&& other) noexcept: _inArrs(std::move(other._inArrs)),
                                                           _tArgs(std::move(other._tArgs)),
                                                           _iArgs(std::move(other._iArgs)),
                                                           _bArgs(std::move(other._bArgs)),
                                                           _isArrAlloc(std::move(other._isArrAlloc)) {

    _numInArrs = _inArrs.size();
    _numTArgs  = _tArgs.size();
    _numIArgs  = _iArgs.size();
    _numBArgs  = _bArgs.size();

    // a moved-from std::vector is only "valid but unspecified": make the empty state explicit
    other._inArrs.clear();
    other._tArgs.clear();
    other._iArgs.clear();
    other._bArgs.clear();
    other._isArrAlloc.clear();      // nothing left for other's destructor to delete

    other._numInArrs = 0;
    other._numTArgs  = 0;
    other._numIArgs  = 0;
    other._numBArgs  = 0;
}
////////////////////////////////////////////////////////////////////////
// assignment operator: intentionally unusable, always throws std::runtime_error
OpArgsHolder& OpArgsHolder::operator=(const OpArgsHolder& other) {
    const char* message = "OpArgsHolder::OpArgsHolder assignment operator: don't use me !";
    throw std::runtime_error(message);
}
////////////////////////////////////////////////////////////////////////
// move assignment operator: releases the arrays this holder owns, then takes over all
// argument lists and ownership flags of `other`.
// `other` is left as an explicitly empty holder (empty lists, zero counters); previously
// its cached counters kept their old values after the move.
OpArgsHolder& OpArgsHolder::operator=(OpArgsHolder&& other) noexcept {

    if (this == &other)
        return *this;

    // delete arrays if necessary (only those flagged as allocated by this holder)
    for (std::vector<bool>::size_type i = 0; i < _isArrAlloc.size(); ++i)
        if(_isArrAlloc[i])
            delete _inArrs[i];

    _inArrs     = std::move(other._inArrs);
    _tArgs      = std::move(other._tArgs);
    _iArgs      = std::move(other._iArgs);
    _bArgs      = std::move(other._bArgs);
    _isArrAlloc = std::move(other._isArrAlloc);

    _numInArrs = _inArrs.size();
    _numTArgs  = _tArgs.size();
    _numIArgs  = _iArgs.size();
    _numBArgs  = _bArgs.size();

    // a moved-from std::vector is only "valid but unspecified": make the empty state explicit
    other._inArrs.clear();
    other._tArgs.clear();
    other._iArgs.clear();
    other._bArgs.clear();
    other._isArrAlloc.clear();      // nothing left for other's destructor to delete

    other._numInArrs = 0;
    other._numTArgs  = 0;
    other._numIArgs  = 0;
    other._numBArgs  = 0;

    return *this;
}
////////////////////////////////////////////////////////////////////////
OpArgsHolder OpArgsHolder::createArgsHolderForBP(const std::vector<NDArray*>& inGradArrs, const bool isInPlace) const {
const int numInGradArrs = inGradArrs.size();
OpArgsHolder result(std::vector<NDArray*>(_numInArrs + numInGradArrs, nullptr), _tArgs, _iArgs);
if(isInPlace)
result._isArrAlloc = std::vector<bool>(_numInArrs + numInGradArrs, false);
for (int i = 0; i < _numInArrs; ++i) {
if(isInPlace) {
if(isInPlace) {
result._inArrs[i] = new NDArray(*_inArrs[i]); // make copy
result._isArrAlloc[i] = true;
}
else
result._inArrs[i] = _inArrs[i];
else
result._inArrs[i] = _inArrs[i];
}
// input gradients
// input gradients
for (int i = 0; i < numInGradArrs; ++i)
result._inArrs[_numInArrs + i] = inGradArrs[i];
@ -53,11 +148,10 @@ OpArgsHolder OpArgsHolder::createArgsHolderForBP(const std::vector<NDArray*>& in
////////////////////////////////////////////////////////////////////////
// destructor: frees only those input arrays that are flagged in _isArrAlloc
OpArgsHolder::~OpArgsHolder() noexcept {
    for (std::vector<bool>::size_type idx = 0; idx < _isArrAlloc.size(); ++idx) {
        if (!_isArrAlloc[idx])
            continue;
        delete _inArrs[idx];
    }
}
}

View File

@ -44,7 +44,7 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) {
auto input = INPUT_VARIABLE(i);
auto currentRank = input->rankOf();
// TODO: follow two lines are accordingly with current tf.concat spec. Commented for compatibility with legacy
// TODO: follow two lines are in accordance to current tf.concat spec. Commented for compatibility with legacy
// REQUIRE_TRUE(currentRank > 0, 0, "Rank of input variable %i must be greater 0, but is %lld instead.", i, currentRank);
// REQUIRE_TRUE(theFirstRank == currentRank, 0, "Number of dimensions in concat should be equals, but for %i input variable %lld != %lld appears.", i, currentRank, theFirstRank);
if(!input->isEmpty()) {

View File

@ -1147,8 +1147,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
// gradO has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC)
// gradI has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC)
gradI.nullify();
const T* x = gradO.bufferAsT<T>();
T* z = gradI.bufferAsT<T>();
@ -1182,8 +1180,10 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
const auto zOffset = b*zStride0 + c*zStride1 + h*zStride2 + w*zStride3;
for(uint xh = h; xh < h + factorH; ++xh)
for(uint xw = w; xw < w + factorW; ++xw)
z[zOffset] = 0;
for(uint xh = h * factorH; xh < h * factorH + factorH; ++xh)
for(uint xw = w * factorW; xw < w * factorW + factorW; ++xw)
z[zOffset] += x[b*xStride0 + c*xStride1 + xh*xStride2 + xw*xStride3];
}
}
@ -1198,8 +1198,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
// input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC)
// output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC)
gradI.nullify();
const T* x = gradO.bufferAsT<T>();
T* z = gradI.bufferAsT<T>();
@ -1238,9 +1236,11 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
const auto zOffset = b*zStride0 + c*zStride1 + d*zStride2 + h*zStride3 + w*zStride4;
for(uint xd = d; xd < d + factorD; ++xd)
for(uint xh = h; xh < h + factorH; ++xh)
for(uint xw = w; xw < w + factorW; ++xw)
z[zOffset] = 0;
for(uint xd = d * factorD; xd < d * factorD + factorD; ++xd)
for(uint xh = h * factorH; xh < h * factorH + factorH; ++xh)
for(uint xw = w * factorW; xw < w * factorW + factorW; ++xw)
z[zOffset] += x[b*xStride0 + c*xStride1 + xd*xStride2 + xh*xStride3 + xw*xStride4];
}
}

View File

@ -26,6 +26,7 @@
namespace nd4j {
namespace ops {
namespace helpers {
nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext();
template <typename T>
static void swapRows_(NDArray* matrix, int theFirst, int theSecond) {
@ -114,7 +115,7 @@ namespace helpers {
NDArray determinant = NDArrayFactory::create<T>(1.f);
NDArray compoundMatrix = *input; // copy
NDArray permutationMatrix(input, false, input->getContext()); // has same shape as input and contiguous strides
NDArray permutationMatrix(input, false, defaultContext); // has same shape as input and contiguous strides
permutationMatrix.setIdentity();
T pivotValue; // = T(0.0);
@ -170,7 +171,7 @@ namespace helpers {
Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n;
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), input->getContext()); //, block.getWorkspace());
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
for (int e = 0; e < output->lengthOf(); e++) {
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row)
@ -184,6 +185,7 @@ namespace helpers {
BUILD_SINGLE_TEMPLATE(template int determinant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
// Entry point for the determinant helper: records `context` and forwards (input, output)
// to the typed determinant_ implementation selected by input->dataType() (FLOAT_TYPES list).
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
// stores the caller's context in the namespace-scope `defaultContext` pointer, which the
// templated helpers in this file pass when creating temporary arrays
// NOTE(review): this is shared mutable state — concurrent calls with different contexts
// would overwrite each other's value; confirm callers are serialized
defaultContext = context;
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (input, output), FLOAT_TYPES);
}
@ -193,7 +195,7 @@ template <typename T>
Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n;
NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), input->getContext()); //, block.getWorkspace());
NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
for (int e = 0; e < output->lengthOf(); e++) {
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) {
matrix.p(row, input->e<T>(k));
@ -220,11 +222,11 @@ template <typename T>
auto totalCount = output->lengthOf() / n2;
output->assign(0.f); // fill up output tensor with zeros
auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), input->getContext()); //, block.getWorkspace());
auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), input->getContext()); //, block.getWorkspace());
auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), input->getContext());
auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), input->getContext());
auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), input->getContext());
auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
for (int e = 0; e < totalCount; e++) {
if (e)
@ -266,6 +268,7 @@ template <typename T>
}
// Entry point for the matrix-inverse helper: records `context` and forwards (input, output)
// to the typed inverse_ implementation selected by input->dataType() (FLOAT_TYPES list).
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
// stores the caller's context in the namespace-scope `defaultContext` pointer, which the
// templated helpers in this file pass when creating temporary arrays
// NOTE(review): this is shared mutable state — concurrent calls with different contexts
// would overwrite each other's value; confirm callers are serialized
defaultContext = context;
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (input, output), FLOAT_TYPES);
}
@ -308,8 +311,8 @@ template <typename T>
if (!inplace)
output->assign(0.f); // fill up output tensor with zeros only inplace=false
std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), input->getContext())); //, block.getWorkspace());
std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), input->getContext()));
std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), defaultContext)); //, block.getWorkspace());
std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), defaultContext));
for (int e = 0; e < totalCount; e++) {
@ -343,6 +346,7 @@ template <typename T>
}
// Entry point for the Cholesky helper: records `context` and forwards (input, output, inplace)
// to the typed cholesky_ implementation selected by input->dataType() (FLOAT_TYPES list).
int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) {
// stores the caller's context in the namespace-scope `defaultContext` pointer, which the
// templated helpers in this file pass when creating temporary arrays
// NOTE(review): this is shared mutable state — concurrent calls with different contexts
// would overwrite each other's value; confirm callers are serialized
defaultContext = context;
BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES);
}

View File

@ -19,8 +19,6 @@
//
#include <ops/declarable/helpers/sg_cb.h>
#include <AveragingArrayProxy.h>
#include <helpers/AveragingArrayProxy.h>
#include <specials.h>
#define HS_MAX_EXP 6.0f

View File

@ -1496,8 +1496,10 @@ __global__ static void upsampling2dBPCuda(const void* vx, const Nd4jLong* xShape
const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
const Nd4jLong zCoord2 = coords[dimIH];
const Nd4jLong zCoord3 = coords[dimIH + 1];
z[zOffset] = 0;
const Nd4jLong zCoord2 = coords[dimIH] * factorH;
const Nd4jLong zCoord3 = coords[dimIH + 1] * factorW;
for(coords[dimIH] = zCoord2; coords[dimIH] < zCoord2 + factorH; ++coords[dimIH])
for(coords[dimIH + 1] = zCoord3; coords[dimIH + 1] < zCoord3 + factorW; ++coords[dimIH + 1])
@ -1569,9 +1571,11 @@ __global__ static void upsampling3dBPCuda(const void* vx, const Nd4jLong* xShape
const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
const Nd4jLong zCoord2 = coords[dimID];
const Nd4jLong zCoord3 = coords[dimID + 1];
const Nd4jLong zCoord4 = coords[dimID + 2];
z[zOffset] = 0;
const Nd4jLong zCoord2 = coords[dimID] * factorD;
const Nd4jLong zCoord3 = coords[dimID + 1] * factorH;
const Nd4jLong zCoord4 = coords[dimID + 2] * factorW;
for(coords[dimID] = zCoord2; coords[dimID] < zCoord2 + factorD; ++coords[dimID])
for(coords[dimID + 1] = zCoord3; coords[dimID + 1] < zCoord3 + factorH; ++coords[dimID + 1])

File diff suppressed because it is too large Load Diff

View File

@ -33,19 +33,19 @@ __global__ static void zetaCuda(const void *vx, const Nd4jLong *xShapeInfo,
const auto x = reinterpret_cast<const T*>(vx);
const auto q = reinterpret_cast<const T*>(vq);
auto z = reinterpret_cast<T*>(vz);
auto z = reinterpret_cast<T*>(vz);
__shared__ Nd4jLong len;
if (threadIdx.x == 0)
len = shape::length(xShapeInfo);
if (threadIdx.x == 0)
len = shape::length(xShapeInfo);
__syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const auto totalThreads = gridDim.x * blockDim.x;
for (int i = tid; i < len; i += totalThreads) {
const auto xOffset = shape::getIndexOffset(i, xShapeInfo, len);
const auto qOffset = shape::getIndexOffset(i, qShapeInfo, len);
const auto zOffset = shape::getIndexOffset(i, zShapeInfo, len);
@ -65,10 +65,10 @@ void zeta(nd4j::LaunchContext * context, const NDArray& x, const NDArray& q, NDA
if(!x.isActualOnDeviceSide()) x.syncToDevice();
if(!q.isActualOnDeviceSide()) q.syncToDevice();
int threadsPerBlock = MAX_NUM_THREADS;
int threadsPerBlock = MAX_NUM_THREADS / 2;
int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
BUILD_SINGLE_SELECTOR(x.dataType(), zetaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), q.getSpecialBuffer(), q.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES);
x.tickReadHost();

View File

@ -1,63 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "testlayers.h"
#include <ops/declarable/CustomOperations.h>
#include <AveragingArrayProxy.h>
using namespace nd4j;
using namespace nd4j::ops;
using namespace nd4j::graph;
// Empty gtest fixture used only to group the AveragingArrayProxy test cases.
class AveragingArrayTests : public testing::Test {
public:
};
// Two private copies for row 1 (values 2 and 4) and one for row 2 (value 3):
// after collapsing, both rows of the source array must equal 3.
TEST_F(AveragingArrayTests, test_basic_reads_1) {
    auto expectedRow = NDArrayFactory::create<double>('c', {1, 5},{3.0, 3.0, 3.0, 3.0, 3.0});

    auto source = NDArrayFactory::create<double>('c', {100, 5});
    source.assign(1.0);

    AveragingArrayProxy proxy(&source);

    auto firstCopyRow1  = proxy.writeable(1, 0);
    auto secondCopyRow1 = proxy.writeable(1, 1);
    auto copyRow2       = proxy.writeable(2, 1);

    ASSERT_TRUE(firstCopyRow1 != nullptr);
    ASSERT_TRUE(secondCopyRow1 != nullptr);
    ASSERT_TRUE(copyRow2 != nullptr);

    firstCopyRow1->assign(2.0);
    secondCopyRow1->assign(4.0);
    copyRow2->assign(3.0);

    auto collapsed = proxy.collapseWrites();
    ASSERT_TRUE(collapsed);

    auto actualRow1 = source({1,2, 0,0}, true);
    auto actualRow2 = source({2,3, 0,0}, true);

    ASSERT_EQ(expectedRow, actualRow1);
    ASSERT_EQ(expectedRow, actualRow2);
}

View File

@ -159,9 +159,9 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_3) {
auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC},{ 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f,
170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f,
152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f,
auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC},{ 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f,
170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f,
152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f,
170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f});
input = 2.;
weights.linspace(0.1, 0.1);
@ -2211,55 +2211,6 @@ TEST_F(ConvolutionTests1, upsampling3d_test2) {
delete results;
}
//////////////////////////////////////////////////////////////////////
// upsampling2d_bp, NCHW layout: an all-ones output gradient of shape
// [bS, iC, iH*factorH, iW*factorW] must yield an input gradient filled with 4.
TEST_F(ConvolutionTests1, upsampling2d_bp_test1) {
    const int bS=1, iH=2,iW=2, iC=1;
    const int factorH=2, factorW=2;
    const int isNCHW = 1; // data format, default is NCHW

    auto input = NDArrayFactory::create<float>('c', {bS, iC, iH, iW});

    auto gradO = NDArrayFactory::create<float>('c', {bS, iC, iH*factorH, iW*factorW});
    gradO = 1.;

    auto expectedGradI = NDArrayFactory::create<float>('c', {bS, iC, iH, iW});
    expectedGradI = 4.;

    nd4j::ops::upsampling2d_bp op;
    auto opResults = op.execute({&input, &gradO}, {}, {isNCHW});
    auto* actualGradI = opResults->at(0);

    ASSERT_EQ(Status::OK(), opResults->status());
    ASSERT_TRUE(expectedGradI.isSameShape(actualGradI));
    ASSERT_TRUE(expectedGradI.equalsTo(actualGradI));

    delete opResults;
}
//////////////////////////////////////////////////////////////////////
// upsampling2d_bp, NHWC layout: an all-ones output gradient of shape
// [bS, iH*factorH, iW*factorW, iC] must yield an input gradient filled with 4.
TEST_F(ConvolutionTests1, upsampling2d_bp_test2) {
    const int bS=1, iH=2,iW=2, iC=1;
    const int factorH=2, factorW=2;
    const int isNCHW = 0; // data format, default is NCHW

    auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});

    auto gradO = NDArrayFactory::create<float>('c', {bS, iH*factorH, iW*factorW, iC});
    gradO = 1.;

    auto expectedGradI = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
    expectedGradI = 4.;

    nd4j::ops::upsampling2d_bp op;
    auto opResults = op.execute({&input, &gradO}, {}, {isNCHW});
    auto* actualGradI = opResults->at(0);

    ASSERT_EQ(Status::OK(), opResults->status());
    ASSERT_TRUE(expectedGradI.isSameShape(actualGradI));
    ASSERT_TRUE(expectedGradI.equalsTo(actualGradI));

    delete opResults;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, upsampling3d_bp_test1) {
@ -2315,6 +2266,70 @@ TYPED_TEST(TypedConvolutionTests1, conv2D_input_BP_test1) {
delete results;
}
//////////////////////////////////////////////////////////////////////
// upsampling3d_bp regression test with non-uniform data: a fixed output gradient of shape
// [bS, iC, iD*factorD, iH*factorH, iW*factorW] is run through the op and the resulting
// input gradient is compared against precomputed values of shape [bS, iC, iD, iH, iW].
TEST_F(ConvolutionTests1, upsampling3d_bp_test3) {
const int bS=1, iD=3,iH=3,iW=3, iC=2;
const int factorD=2, factorH=2, factorW=2;
const int isNCDHW = 1; // data format, default is NCHW
// input contents are left unset; only its shape is specified here
NDArray input('c', {bS, iC, iD, iH, iW}, nd4j::DataType::FLOAT32);
// output gradient fed to the op
NDArray gradO('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, {0.6793504, 0.35508695, 0.84278935, 0.20031333, 0.7014987, 0.31069338,
0.44793984, 0.93800974, 0.32667395, 0.15187258, 0.38331753, 0.78212297, 0.1988072, 0.7985636, 0.1632634, 0.14696825, 0.26089668,
0.13505761, 0.7562093, 0.27545404, 0.36908787, 0.09282647, 0.83649176, 0.26841334, 0.09506222, 0.31279507, 0.13591796, 0.5175439,
0.32870287, 0.061735712, 0.39643127, 0.248016, 0.5489592, 0.115046196, 0.8143622, 0.7215636, 0.40449402, 0.29908907, 0.4038839,
0.9883108, 0.022296403, 0.927782, 0.3184157, 0.0685462, 0.28453344, 0.23272, 0.35214192, 0.058909304, 0.7112212, 0.6744568, 0.19694561,
0.6994972, 0.0743224, 0.42042503, 0.5842631, 0.14957358, 0.44640633, 0.72307247, 0.06448108, 0.48307765, 0.8759956, 0.5698191, 0.4458631,
0.5277549, 0.016646361, 0.753678, 0.14063567, 0.7541292, 0.16193217, 0.7750374, 0.3326449, 0.11739397, 0.017710684, 0.60847557, 0.52515227,
0.9171938, 0.84989065, 0.5894228, 0.85227835, 0.39063585, 0.88968325, 0.6694452, 0.698873, 0.96147966, 0.15740126, 0.15736352, 0.49352047,
0.5699365, 0.12683152, 0.11572781, 0.7863682, 0.737939, 0.49007934, 0.6084143, 0.9564999, 0.3900982, 0.14730452, 0.8506447, 0.49765033,
0.07186628, 0.08214969, 0.035314173, 0.7320408, 0.36993408, 0.8406658, 0.27389422, 0.43179566, 0.13323106, 0.19297548, 0.24689731, 0.38641843,
0.51154125, 0.19903564, 0.1416313, 0.69769853, 0.25363067, 0.78221816, 0.9300991, 0.3355119, 0.5588076, 0.6643576, 0.018850708, 0.63755876,
0.2904297, 0.43490165, 0.84251267, 0.46609768, 0.38139546, 0.52318525, 0.9901826, 0.9257676, 0.6434591, 0.016828254, 0.9187561, 0.22897908,
0.0063138064, 0.66597503, 0.19036093, 0.59552056, 0.69888055, 0.22146936, 0.9124342, 0.8708221, 0.7273687, 0.52397245, 0.66288394, 0.2188415,
0.3354802, 0.03566524, 0.5101009, 0.5017283, 0.75122046, 0.1884508, 0.7407126, 0.6253045, 0.47145858, 0.5369367, 0.19884548, 0.99008304,
0.08256686, 0.91884845, 0.02360027, 0.98895234, 0.3751719, 0.91783875, 0.4338776, 0.6783008, 0.6667967, 0.46720362, 0.7508773, 0.52304846,
0.76631916, 0.4187526, 0.7653719, 0.5159193, 0.42730415, 0.49462363, 0.2731735, 0.8862948, 0.043214794, 0.3197591, 0.040378205, 0.5427239,
0.9228089, 0.045940384, 0.70047987, 0.8419288, 0.53966296, 0.009444186, 0.038044546, 0.03158029, 0.43485752, 0.9204235, 0.5478789, 0.8290083,
0.11868837, 0.0229866, 0.6639305, 0.8757367, 0.8279557, 0.76270294, 0.43242732, 0.4713431, 0.2569212, 0.30575937, 0.44395888, 0.99384075,
0.6127142, 0.44844577, 0.6347944, 0.098358564, 0.34233716, 0.9329664, 0.65776783, 0.108565055, 0.2052629, 0.46441218, 0.041791342, 0.89369565,
0.7000381, 0.2106213, 0.51152664, 0.44200692, 0.8293282, 0.20901772, 0.6387249, 0.8016979, 0.11178707, 0.109545894, 0.19654618, 0.060582615,
0.08239174, 0.64630795, 0.32862368, 0.60225064, 0.8328141, 0.5484566, 0.8120276, 0.38822946, 0.6742381, 0.34913155, 0.42887798, 0.45344824,
0.73956585, 0.9714739, 0.42937812, 0.45185348, 0.84535813, 0.046436775, 0.8802151, 0.8676222, 0.42625394, 0.4985318, 0.42399272, 0.122144565,
0.0060101906, 0.47253844, 0.18123977, 0.86316174, 0.5863874, 0.3852012, 0.9785553, 0.0054711984, 0.88500834, 0.020897374, 0.27467912, 0.3852802,
0.0766939, 0.94622654, 0.38687763, 0.3308602, 0.7770494, 0.9052543, 0.22258204, 0.42207044, 0.18050623, 0.21057767, 0.012561422, 0.7977821,
0.61251044, 0.7203693, 0.6028265, 0.6036933, 0.1446382, 0.6712341, 0.76634467, 0.4854034, 0.26634562, 0.76523924, 0.16348523, 0.2663676,
0.96846986, 0.8273284, 0.10700377, 0.7600526, 0.6771002, 0.47963092, 0.21264452, 0.56934077, 0.5514792, 0.85725874, 0.99090636, 0.54562527,
0.93597686, 0.21142527, 0.4628326, 0.35011524, 0.31464386, 0.31164807, 0.65928996, 0.94418925, 0.39666295, 0.9496393, 0.103756346, 0.482158,
0.49171793, 0.4108867, 0.22594318, 0.97093135, 0.5974685, 0.34632966, 0.54835194, 0.10499302, 0.9767778, 0.55008715, 0.54379046, 0.3583731,
0.33369112, 0.04279039, 0.24939054, 0.23943715, 0.06775989, 0.7750291, 0.24329625, 0.4327169, 0.86916673, 0.80322117, 0.049972698, 0.47177452,
0.37419558, 0.15303156, 0.121425234, 0.75884604, 0.8191354, 0.48554084, 0.053899214, 0.7858246, 0.39219773, 0.77579063, 0.34507045, 0.46070176,
0.14496958, 0.47706795, 0.50678796, 0.64902323, 0.3277943, 0.0017530271, 0.6536156, 0.8582253, 0.95703506, 0.9963951, 0.8239163, 0.305142,
0.012419582, 0.9498972, 0.1595827, 0.47947606, 0.5071124, 0.78227425, 0.2066719, 0.5217094, 0.7841406, 0.5260441, 0.49798164, 0.10975622,
0.8633349, 0.76298475, 0.14295428, 0.6131504, 0.43794408, 0.50339264, 0.4504877, 0.19235311, 0.6678411, 0.80769485, 0.67495126, 0.96461457,
0.10535406, 0.66438645, 0.4372345, 0.93851465, 0.8635335, 0.3405871, 0.45652762, 0.3636232, 0.52931345, 0.20154329, 0.07698499, 0.6125804,
0.3583082, 0.3894796, 0.32601944, 0.5237369, 0.66683626, 0.08541841, 0.4815708, 0.11897489, 0.97555137, 0.3602705, 0.9620871, 0.6361821,
0.71167386, 0.5134439, 0.57761437, 0.58598644, 0.39387667, 0.6966405, 0.46841687, 0.85788506, 0.9957087, 0.051309288, 0.24846801, 0.55938333,
0.10230542, 0.9370694, 0.57527155, 0.54656035, 0.28896323, 0.51303476, 0.8865, 0.38641605, 0.9836358}, nd4j::DataType::FLOAT32);
// precomputed expected input gradient
NDArray expGradI('c', {bS, iC, iD, iH, iW}, {3.510932, 3.4310975, 3.538762, 4.148549, 2.8380678, 2.5431657, 3.3928843, 3.228055, 3.1467278,
3.2603023, 5.611751, 4.334653, 3.3697734, 4.603307, 4.4357986, 4.32991, 3.0532732, 3.1370173, 4.181534, 2.9965065, 2.8553872, 5.2719016,
4.5671935, 3.7027276, 3.3517184, 5.2544537, 3.5107024, 4.1496124, 3.9333878, 3.1798909, 3.1446428, 3.0932689, 3.9730802, 3.0466917,
4.9675374, 4.769673, 3.766952, 3.6375027, 3.6492167, 4.9440994, 3.8379507, 3.467589, 4.719474, 3.1295977, 4.5177174, 4.2760015, 2.8443856,
4.225355, 4.377341, 4.4398847, 4.710785, 4.4199953, 3.928307, 4.8769503}, nd4j::DataType::FLOAT32);
nd4j::ops::upsampling3d_bp op;
auto results = op.execute({&input, &gradO}, {}, {isNCDHW});
auto* gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expGradI.isSameShape(gradI));
ASSERT_TRUE(expGradI.equalsTo(gradI));
// the result set is heap-allocated by execute() and must be released by the test
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, deconv2d_test1) {

File diff suppressed because one or more lines are too long

View File

@ -640,40 +640,6 @@ TEST_F(DeclarableOpsTests1, ClipByValue1) {
delete block;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, MergeMaxTest1) {
    // Runs mergemax through the VariableSpace/Context execution path on three
    // constant 5x5 float arrays (3, 1 and 2) and expects an output filled with 3.
    auto x = NDArrayFactory::create_<float>('c', {5, 5});
    auto y = NDArrayFactory::create_<float>('c', {5, 5});
    auto z = NDArrayFactory::create_<float>('c', {5, 5});
    auto exp = NDArrayFactory::create<float>('c', {5, 5});
    x->assign(3);
    y->assign(1);
    z->assign(2);
    exp.assign(3);

    // Inputs are registered under negative ids; id 1 holds the output placeholder.
    // NOTE(review): x, y and z are never deleted in this test - presumably the
    // VariableSpace takes ownership once they are registered; confirm.
    auto variableSpace = new VariableSpace();
    variableSpace->putVariable(-1, x);
    variableSpace->putVariable(-2, y);
    variableSpace->putVariable(-3, z);
    variableSpace->putVariable(1, new Variable(NDArrayFactory::create_<float>('c', {5, 5})));

    auto block = new Context(1, variableSpace, false);
    block->fillInputs({-1, -2, -3});

    // Fail fast if the op itself reports an error instead of relying only on
    // the value comparison below.
    nd4j::ops::mergemax merge;
    ASSERT_EQ(ND4J_STATUS_OK, merge.execute(block));

    auto res = variableSpace->getVariable(1)->getNDArray();
    ASSERT_TRUE(res->equalsTo(&exp));

    delete block;
    delete variableSpace;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, MergeAvgTest1) {

View File

@ -845,5 +845,44 @@ TEST_F(DeclarableOpsTests13, batch_to_space_nd_3) {
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, mergemax_1) {
    // mergemax over three constant 5x5 FLOAT32 arrays (3, 1, 2) must yield
    // an array of the same shape filled with 3.
    NDArray first('c', {5, 5}, nd4j::DataType::FLOAT32);
    first.assign(3);

    NDArray second('c', {5, 5}, nd4j::DataType::FLOAT32);
    second.assign(1);

    NDArray third('c', {5, 5}, nd4j::DataType::FLOAT32);
    third.assign(2);

    NDArray expected('c', {5, 5}, nd4j::DataType::FLOAT32);
    expected.assign(3);

    nd4j::ops::mergemax op;
    auto results = op.execute({&first, &second, &third}, {}, {});
    ASSERT_EQ(Status::OK(), results->status());

    auto output = results->at(0);
    // output->printBuffer();   // uncomment when debugging a mismatch
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, mergemax_2) {
    // Feeds mergemax two inputs of different shapes ({1, 3} and {1, 1}) with a
    // pre-allocated output and expects execution to return status code 20.
    // NOTE(review): 20 presumably corresponds to a validation-failure status
    // constant - confirm and consider using the named constant instead.
    NDArray wide('c', {1, 3}, {0., 1, 2}, nd4j::DataType::FLOAT32);
    NDArray narrow('c', {1, 1}, {1.}, nd4j::DataType::FLOAT32);
    NDArray output('c', {1, 3}, {-1., -1, -1}, nd4j::DataType::FLOAT32);

    nd4j::ops::mergemax op;
    const auto status = op.execute({&wide, &narrow}, {&output}, {}, {}, {});

    ASSERT_EQ(20, status);
}

View File

@ -1548,8 +1548,8 @@ TEST_F(DeclarableOpsTests6, LogMatrixDeterminant_1) {
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
//z->printIndexedBuffer("Output ");
//exp.printIndexedBuffer("Expected ");
z->printIndexedBuffer("Log ABS Output ");
exp.printIndexedBuffer("Log ABS Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1671,6 +1671,40 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_010) {
    // Inverts a single 5x5 lower-triangular matrix wrapped in a {1, 5, 5}
    // float array and compares the result with a precomputed inverse.
    auto input = NDArrayFactory::create<float>('c', {1, 5, 5}, {
        1., 0., 0., 0., 0.,
        2., 1., 0., 0., 0.,
        30., 2., 1., 0., 0.,
        4., 3., 2., 1., 0.,
        5., 4., 3., 2., 1.,
    });

    auto expected = NDArrayFactory::create<float>('c', {1, 5, 5}, {
        1.0, 0.0, 0.0, 0.0, 0.,
        -2.0, 1.0, 0., 0., 0.,
        -26.0, -2.0, 1, 0, 0.,
        54.0, 1.0, -2.0, 1, 0.,
        -27.0, 0.0, 1.0, -2.0, 1.
    });

    nd4j::ops::matrix_inverse op;
    auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto output = results->at(0);
    // Debug helpers, left disabled to keep the test output quiet:
    // output->printIndexedBuffer("010 Output ");
    // expected.printIndexedBuffer("010 Expected ");
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_01) {
@ -1824,7 +1858,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_03) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
auto x = NDArrayFactory::create<double>('c', {5, 5}, {
auto x = NDArrayFactory::create<float>('c', {5, 5}, {
4., 0., 0., 0., 0.,
4., 2., 0., 0., 0.,
30., 2., 1., 0., 0.,
@ -1832,7 +1866,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
15., 12., 9., 6., 3.,
});
auto exp = NDArrayFactory::create<double>('c', {5, 5}, {
auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
0.25, 0.0, 0.0, 0.0, 0.0,
-0.50, 0.5, 0.0, 0.0, 0.0,
-6.50, -1.0, 1.0, 0.0, 0.0,
@ -1841,13 +1875,13 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
});
nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
exp.printIndexedBuffer("Expected ");
z->printIndexedBuffer("Output ");
// exp.printIndexedBuffer("Expected ");
// z->printIndexedBuffer("Output ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1880,8 +1914,42 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("Output ");
exp.printIndexedBuffer("Expected ");
// z->printIndexedBuffer("Output ");
// exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_04) {
auto x = NDArrayFactory::create<float>('c', {5, 5}, {
1., 2., 30., 4., 5.,
0., 1., 2., 3., 4.,
0., 0., 1., 2., 3.,
0., 0., 0., 1., 2.,
0., 0., 0., 0., 1.
});
auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
1.0, -2.0, -26.0, 54.0, -27.0,
0.0, 1.0, -2.0, 1.0, 0.0,
0.0, 0.0, 1.0, -2.0, 1.0,
0.0, 0.0, 0.0, 1.0, -2.0,
0.0, 0.0, 0.0, 0.0, 1.0
});
nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printIndexedBuffer("Output ");
// exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));

View File

@ -9363,15 +9363,28 @@ public static final int PREALLOC_SIZE = 33554432;
// #include <NDArray.h>
// #include <dll.h>
@Namespace("nd4j") @NoOffset public static class OpArgsHolder extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public OpArgsHolder(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public OpArgsHolder(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public OpArgsHolder position(long position) {
return (OpArgsHolder)super.position(position);
}
// default constructor
public OpArgsHolder() { super((Pointer)null); allocate(); }
private native void allocate();
// copy constructor
public OpArgsHolder(@Const @ByRef OpArgsHolder other) { super((Pointer)null); allocate(other); }
private native void allocate(@Const @ByRef OpArgsHolder other);
// constructor
public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); }
private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/);
public OpArgsHolder(@Const @ByRef NDArrayVector inArrs) { super((Pointer)null); allocate(inArrs); }
@ -9387,6 +9400,13 @@ public static final int PREALLOC_SIZE = 33554432;
public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); }
private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/);
// move constructor
// assignment operator
public native @ByRef @Name("operator =") OpArgsHolder put(@Const @ByRef OpArgsHolder other);
// move assignment operator
public native @Const @ByRef NDArrayVector getInArrs();
public native @StdVector DoublePointer getTArgs();
@ -9406,8 +9426,8 @@ public static final int PREALLOC_SIZE = 33554432;
public native int getNumBArgs();
public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs, @Cast("const bool") boolean isInPlace/*=false*/);
public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs);
public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs);
}

View File

@ -9060,15 +9060,28 @@ public static final int PREALLOC_SIZE = 33554432;
// #include <NDArray.h>
// #include <dll.h>
@Namespace("nd4j") @NoOffset public static class OpArgsHolder extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public OpArgsHolder(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public OpArgsHolder(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public OpArgsHolder position(long position) {
return (OpArgsHolder)super.position(position);
}
// default constructor
public OpArgsHolder() { super((Pointer)null); allocate(); }
private native void allocate();
// copy constructor
public OpArgsHolder(@Const @ByRef OpArgsHolder other) { super((Pointer)null); allocate(other); }
private native void allocate(@Const @ByRef OpArgsHolder other);
// constructor
public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); }
private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/);
public OpArgsHolder(@Const @ByRef NDArrayVector inArrs) { super((Pointer)null); allocate(inArrs); }
@ -9084,6 +9097,13 @@ public static final int PREALLOC_SIZE = 33554432;
public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); }
private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/);
// move constructor
// assignment operator
public native @ByRef @Name("operator =") OpArgsHolder put(@Const @ByRef OpArgsHolder other);
// move assignment operator
public native @Const @ByRef NDArrayVector getInArrs();
public native @StdVector DoublePointer getTArgs();
@ -9103,8 +9123,8 @@ public static final int PREALLOC_SIZE = 33554432;
public native int getNumBArgs();
public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs, @Cast("const bool") boolean isInPlace/*=false*/);
public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs);
public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs);
}

View File

@ -45,6 +45,7 @@ import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.nativeblas.NativeOpsHolder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.*;
@ -548,6 +549,8 @@ public class CustomOpsTests extends BaseNd4jTest {
Nd4j.exec(op); //Execution is OK
}
@Test
public void testDepthwise(){
INDArray input = Nd4j.create(DataType.DOUBLE, 1,3,8,8);
@ -625,6 +628,49 @@ public class CustomOpsTests extends BaseNd4jTest {
System.out.println(out);
}
@Test
public void testUpsampling2dBackprop(){
    // Checks upsampling2d_bp against a reference computed in Java: every
    // element of the incoming gradient is accumulated into the input cell
    // it was upsampled from (index divided by the scale factor).
    Nd4j.getRandom().setSeed(12345);

    final int channels = 2;
    final int[] scale = {2, 2};
    final long[] inShape = {1, channels, 3, 3};

    INDArray gradOut = Nd4j.rand(DataType.FLOAT, 1, channels, scale[0] * inShape[2], scale[1] * inShape[3]);
    INDArray input = Nd4j.create(inShape);        //Only passed to the op; its values are not used for the expected result
    INDArray expected = Nd4j.create(DataType.FLOAT, inShape);

    final long outHeight = gradOut.size(2);
    final long outWidth = gradOut.size(3);
    for (int ch = 0; ch < channels; ch++) {
        for (int row = 0; row < outHeight; row++) {
            for (int col = 0; col < outWidth; col++) {
                int[] src = new int[]{0, ch, row, col};
                int[] dst = new int[]{0, ch, row / scale[0], col / scale[1]};
                // Sum the gradient of every upsampled copy back into its source cell
                expected.putScalar(dst, expected.getFloat(dst) + gradOut.getFloat(src));
            }
        }
    }

    System.out.println("Eps:");
    System.out.println(gradOut.shapeInfoToString());
    System.out.println(Arrays.toString(gradOut.data().asFloat()));

    System.out.println("Expected:");
    System.out.println(expected.shapeInfoToString());
    System.out.println(Arrays.toString(expected.data().asFloat()));

    DynamicCustomOp op = DynamicCustomOp.builder("upsampling2d_bp")
            .addInputs(input, gradOut)
            .addOutputs(expected.ulike())
            .addIntegerArguments(1)     //1 = NCHW
            .build();
    Nd4j.exec(op);

    INDArray actual = op.getOutputArgument(0);
    assertEquals(expected, actual);
}
@Test
public void testIsMaxView(){