[WIP] more fixes (#159)

* Added test for MatrixInverse with double input. Fixed matrixDeterminantKernel.

* Fixed kernels to avoid waste templating.

* Fixed logDeterminant kernel.

* Refactored type check for lup'

* - decrease blockDim value for zeta op

Signed-off-by: Yurii <yurii@skymind.io>

* Added print for compound matrix with CUDA.

* Refactored upper matrix invertion kernels.

* - provide move constructor and move assignment operator for OpArgsHoder class

Signed-off-by: Yurii <yurii@skymind.io>

* Refactored usage of launch context.

* - add test for mergemax

Signed-off-by: Yurii <yurii@skymind.io>

* get rid of AveragingArrayProxy

Signed-off-by: raver119 <raver119@gmail.com>

* Refactoring of LUP inversion.

* Added prints for invertion.

* - add OpArgsHolder copy constructor and assignment operator

Signed-off-by: Yurii <yurii@skymind.io>

* Added test for lower inversion

* - fix bug in upsampling2d/3d_bp op

Signed-off-by: Yurii <yurii@skymind.io>

* Added expensive printfs to kernel.

* Refactored expensive kernel prints.

* Refactored expensive printfs

* - remove nullify

Signed-off-by: Yurii <yurii@skymind.io>

* Eliminated waste prints with tests.

* upsampling2d_bp test

Signed-off-by: raver119 <raver119@gmail.com>

* test updated

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-23 19:20:50 +03:00 committed by GitHub
parent 99cdf6d42b
commit f03b0ee78f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 1241 additions and 1009 deletions

View File

@ -1,59 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef DEV_TESTS_AVERAGINGARRAYPROXY_H
#define DEV_TESTS_AVERAGINGARRAYPROXY_H
#include "NDArray.h"
#include <utility>
#include <map>
#include <vector>
#include <mutex>
namespace nd4j {
class ND4J_EXPORT AveragingArrayProxy {
protected:
NDArray *_original;
std::map<std::pair<int,int>, NDArray*> _writeables;
std::map<int, std::vector<NDArray*>> _writeablesLinear;
std::vector<int> _rows;
std::vector<NDArray*> _references;
std::mutex _lock;
public:
explicit AveragingArrayProxy(NDArray *original);
~AveragingArrayProxy();
NDArray* readable(int row, int key);
NDArray* writeable(int row, int key);
bool isEmpty();
bool writeableExists(std::pair<int, int> &key);
bool writeableExists(int row, int key);
bool collapseWrites();
};
}
#endif //DEV_TESTS_AVERAGINGARRAYPROXY_H

View File

@ -30,23 +30,38 @@ namespace nd4j {
class ND4J_EXPORT OpArgsHolder { class ND4J_EXPORT OpArgsHolder {
private: private:
std::vector<NDArray*> _inArrs = std::vector<NDArray*>(); std::vector<NDArray*> _inArrs = std::vector<NDArray*>();
std::vector<double> _tArgs = std::vector<double>(); std::vector<double> _tArgs = std::vector<double>();
std::vector<Nd4jLong> _iArgs = std::vector<Nd4jLong>(); std::vector<Nd4jLong> _iArgs = std::vector<Nd4jLong>();
std::vector<bool> _bArgs = std::vector<bool>(); std::vector<bool> _bArgs = std::vector<bool>();
std::vector<bool> _isArrAlloc = std::vector<bool>();
int _numInArrs = _inArrs.size(); int _numInArrs = _inArrs.size();
int _numTArgs = _tArgs.size(); int _numTArgs = _tArgs.size();
int _numIArgs = _iArgs.size(); int _numIArgs = _iArgs.size();
int _numBArgs = _bArgs.size(); int _numBArgs = _bArgs.size();
std::vector<bool> _isArrAlloc = std::vector<bool>();
public: public:
OpArgsHolder() = delete; // default constructor
OpArgsHolder();
OpArgsHolder(const std::vector<NDArray*>& inArrs, const std::vector<double>& tArgs = std::vector<double>(), const std::vector<Nd4jLong>& iArgs = std::vector<Nd4jLong>(), const std::vector<bool>& bArgs = std::vector<bool>()) // copy constructor
: _inArrs(inArrs), _tArgs(tArgs), _iArgs(iArgs), _bArgs(bArgs) { } OpArgsHolder(const OpArgsHolder& other);
// constructor
OpArgsHolder(const std::vector<NDArray*>& inArrs, const std::vector<double>& tArgs = std::vector<double>(), const std::vector<Nd4jLong>& iArgs = std::vector<Nd4jLong>(), const std::vector<bool>& bArgs = std::vector<bool>());
// move constructor
OpArgsHolder(OpArgsHolder&& other) noexcept;
// assignment operator
OpArgsHolder& operator=(const OpArgsHolder& other);
// move assignment operator
OpArgsHolder& operator=(OpArgsHolder&& other) noexcept;
const std::vector<NDArray*>& getInArrs() const const std::vector<NDArray*>& getInArrs() const
{return _inArrs; } {return _inArrs; }

View File

@ -229,7 +229,6 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc); status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc);
} }
else if(ABC && aType == DataType::HALF) { else if(ABC && aType == DataType::HALF) {
printf("!!!!!!!!\n");
float16 alphaH(alpha), betaH(beta); float16 alphaH(alpha), betaH(beta);
status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc); status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc);
} }

View File

@ -1,139 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../AveragingArrayProxy.h"
namespace nd4j {
AveragingArrayProxy::AveragingArrayProxy(NDArray *original) {
_original = original;
}
AveragingArrayProxy::~AveragingArrayProxy() {
for (auto v:_references)
delete v;
}
bool AveragingArrayProxy::writeableExists(std::pair<int, int> &key) {
_lock.lock();
auto r = _writeables.count(key) > 0;
_lock.unlock();
return r;
}
bool AveragingArrayProxy::writeableExists(int row, int key) {
std::pair<int, int> k(row, key);
return writeableExists(k);
}
NDArray* AveragingArrayProxy::readable(int row, int key) {
std::pair<int, int> k(row, key);
if (writeableExists(k)) {
_lock.lock();
auto r = _writeables[k];
_lock.unlock();
return r;
} else {
auto readable = (*_original)({row,row+1, 0,0});
_lock.lock();
_references.emplace_back(&readable);
_lock.unlock();
// return readable;
}
}
bool AveragingArrayProxy::isEmpty() {
return _original->isEmpty();
}
NDArray* AveragingArrayProxy::writeable(int row, int key) {
std::pair<int, int> k(row, key);
// if writeable exists - just return it
if (writeableExists(k)) {
_lock.lock();
auto r = _writeables[k];
_lock.unlock();
return r;
} else {
// if doesn't - let's create it
auto orig = (*_original)({row,row+1, 0,0});
// we don't want views here for obvious reasons
auto writeable = orig.dup('c');
_lock.lock();
_writeables[k] = writeable;
_references.emplace_back(writeable);
// storing linear reference, for future averaging step
if (_writeablesLinear.count(row) == 0) {
std::vector<NDArray*> vec;
_writeablesLinear[row] = vec;
}
_writeablesLinear[row].emplace_back(writeable);
_rows.emplace_back(row);
_lock.unlock();
return writeable;
}
}
bool AveragingArrayProxy::collapseWrites() {
if (_writeables.empty())
return false;
for (int r = 0; r < _rows.size(); r++) {
auto row = _rows[r];
auto list = _writeablesLinear.at(row);
auto originalRow = (*_original)({row,row+1, 0,0});
if (list.size() == 1) {
originalRow.assign(list.at(0));
} else {
originalRow.assign(0.0);
for (int e = 0; e < list.size(); e++)
originalRow += *(list.at(e));
originalRow /= (int) list.size();
}
}
return true;
}
}

View File

@ -23,6 +23,101 @@
namespace nd4j { namespace nd4j {
////////////////////////////////////////////////////////////////////////
// default constructor
OpArgsHolder::OpArgsHolder() {
_inArrs = std::vector<NDArray*>();
_tArgs = std::vector<double>();
_iArgs = std::vector<Nd4jLong>();
_bArgs = std::vector<bool>();
_isArrAlloc = std::vector<bool>();
_numInArrs = 0;
_numTArgs = 0;
_numIArgs = 0;
_numBArgs = 0;
}
////////////////////////////////////////////////////////////////////////
// copy constructor
OpArgsHolder::OpArgsHolder(const OpArgsHolder& other) {
throw std::runtime_error("OpArgsHolder::OpArgsHolder copy constructor: don't use me !");
}
////////////////////////////////////////////////////////////////////////
// constructor
OpArgsHolder::OpArgsHolder(const std::vector<NDArray*>& inArrs,
const std::vector<double>& tArgs,
const std::vector<Nd4jLong>& iArgs,
const std::vector<bool>& bArgs) {
_inArrs = inArrs;
_tArgs = tArgs;
_iArgs = iArgs;
_bArgs = bArgs;
_isArrAlloc = std::vector<bool>();
_numInArrs = _inArrs.size();
_numTArgs = _tArgs.size();
_numIArgs = _iArgs.size();
_numBArgs = _bArgs.size();
}
////////////////////////////////////////////////////////////////////////
// move constructor
OpArgsHolder::OpArgsHolder(OpArgsHolder&& other) noexcept: _inArrs(std::move(other._inArrs)),
_tArgs(std::move(other._tArgs)),
_iArgs(std::move(other._iArgs)),
_bArgs(std::move(other._bArgs)),
_isArrAlloc(std::move(other._isArrAlloc)) {
other._isArrAlloc = std::vector<bool>();
_numInArrs = _inArrs.size();
_numTArgs = _tArgs.size();
_numIArgs = _iArgs.size();
_numBArgs = _bArgs.size();
}
////////////////////////////////////////////////////////////////////////
// assignment operator
OpArgsHolder& OpArgsHolder::operator=(const OpArgsHolder& other) {
throw std::runtime_error("OpArgsHolder::OpArgsHolder assignment operator: don't use me !");
}
////////////////////////////////////////////////////////////////////////
// move assignment operator
OpArgsHolder& OpArgsHolder::operator=(OpArgsHolder&& other) noexcept {
if (this == &other)
return *this;
for (int i = 0; i < _isArrAlloc.size(); ++i) // delete arrays if necessary
if(_isArrAlloc[i])
delete _inArrs[i];
_inArrs = std::move(other._inArrs);
_tArgs = std::move(other._tArgs);
_iArgs = std::move(other._iArgs);
_bArgs = std::move(other._bArgs);
_isArrAlloc = std::move(other._isArrAlloc);
other._isArrAlloc = std::vector<bool>();
_numInArrs = _inArrs.size();
_numTArgs = _tArgs.size();
_numIArgs = _iArgs.size();
_numBArgs = _bArgs.size();
return *this;
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
OpArgsHolder OpArgsHolder::createArgsHolderForBP(const std::vector<NDArray*>& inGradArrs, const bool isInPlace) const { OpArgsHolder OpArgsHolder::createArgsHolderForBP(const std::vector<NDArray*>& inGradArrs, const bool isInPlace) const {
@ -57,7 +152,6 @@ OpArgsHolder::~OpArgsHolder() noexcept {
for (int i = 0; i < _isArrAlloc.size(); ++i) for (int i = 0; i < _isArrAlloc.size(); ++i)
if(_isArrAlloc[i]) if(_isArrAlloc[i])
delete _inArrs[i]; delete _inArrs[i];
} }
} }

View File

@ -44,7 +44,7 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) {
auto input = INPUT_VARIABLE(i); auto input = INPUT_VARIABLE(i);
auto currentRank = input->rankOf(); auto currentRank = input->rankOf();
// TODO: follow two lines are accordingly with current tf.concat spec. Commented for compatibility with legacy // TODO: follow two lines are in accordance to current tf.concat spec. Commented for compatibility with legacy
// REQUIRE_TRUE(currentRank > 0, 0, "Rank of input variable %i must be greater 0, but is %lld instead.", i, currentRank); // REQUIRE_TRUE(currentRank > 0, 0, "Rank of input variable %i must be greater 0, but is %lld instead.", i, currentRank);
// REQUIRE_TRUE(theFirstRank == currentRank, 0, "Number of dimensions in concat should be equals, but for %i input variable %lld != %lld appears.", i, currentRank, theFirstRank); // REQUIRE_TRUE(theFirstRank == currentRank, 0, "Number of dimensions in concat should be equals, but for %i input variable %lld != %lld appears.", i, currentRank, theFirstRank);
if(!input->isEmpty()) { if(!input->isEmpty()) {

View File

@ -1147,8 +1147,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
// gradO has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) // gradO has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC)
// gradI has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) // gradI has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC)
gradI.nullify();
const T* x = gradO.bufferAsT<T>(); const T* x = gradO.bufferAsT<T>();
T* z = gradI.bufferAsT<T>(); T* z = gradI.bufferAsT<T>();
@ -1182,8 +1180,10 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
const auto zOffset = b*zStride0 + c*zStride1 + h*zStride2 + w*zStride3; const auto zOffset = b*zStride0 + c*zStride1 + h*zStride2 + w*zStride3;
for(uint xh = h; xh < h + factorH; ++xh) z[zOffset] = 0;
for(uint xw = w; xw < w + factorW; ++xw)
for(uint xh = h * factorH; xh < h * factorH + factorH; ++xh)
for(uint xw = w * factorW; xw < w * factorW + factorW; ++xw)
z[zOffset] += x[b*xStride0 + c*xStride1 + xh*xStride2 + xw*xStride3]; z[zOffset] += x[b*xStride0 + c*xStride1 + xh*xStride2 + xw*xStride3];
} }
} }
@ -1198,8 +1198,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
// input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) // input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC)
// output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) // output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC)
gradI.nullify();
const T* x = gradO.bufferAsT<T>(); const T* x = gradO.bufferAsT<T>();
T* z = gradI.bufferAsT<T>(); T* z = gradI.bufferAsT<T>();
@ -1238,9 +1236,11 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
const auto zOffset = b*zStride0 + c*zStride1 + d*zStride2 + h*zStride3 + w*zStride4; const auto zOffset = b*zStride0 + c*zStride1 + d*zStride2 + h*zStride3 + w*zStride4;
for(uint xd = d; xd < d + factorD; ++xd) z[zOffset] = 0;
for(uint xh = h; xh < h + factorH; ++xh)
for(uint xw = w; xw < w + factorW; ++xw) for(uint xd = d * factorD; xd < d * factorD + factorD; ++xd)
for(uint xh = h * factorH; xh < h * factorH + factorH; ++xh)
for(uint xw = w * factorW; xw < w * factorW + factorW; ++xw)
z[zOffset] += x[b*xStride0 + c*xStride1 + xd*xStride2 + xh*xStride3 + xw*xStride4]; z[zOffset] += x[b*xStride0 + c*xStride1 + xd*xStride2 + xh*xStride3 + xw*xStride4];
} }
} }

View File

@ -26,6 +26,7 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext();
template <typename T> template <typename T>
static void swapRows_(NDArray* matrix, int theFirst, int theSecond) { static void swapRows_(NDArray* matrix, int theFirst, int theSecond) {
@ -114,7 +115,7 @@ namespace helpers {
NDArray determinant = NDArrayFactory::create<T>(1.f); NDArray determinant = NDArrayFactory::create<T>(1.f);
NDArray compoundMatrix = *input; // copy NDArray compoundMatrix = *input; // copy
NDArray permutationMatrix(input, false, input->getContext()); // has same shape as input and contiguous strides NDArray permutationMatrix(input, false, defaultContext); // has same shape as input and contiguous strides
permutationMatrix.setIdentity(); permutationMatrix.setIdentity();
T pivotValue; // = T(0.0); T pivotValue; // = T(0.0);
@ -170,7 +171,7 @@ namespace helpers {
Nd4jLong n = input->sizeAt(-1); Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n; Nd4jLong n2 = n * n;
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), input->getContext()); //, block.getWorkspace()); auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
for (int e = 0; e < output->lengthOf(); e++) { for (int e = 0; e < output->lengthOf(); e++) {
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row)
@ -184,6 +185,7 @@ namespace helpers {
BUILD_SINGLE_TEMPLATE(template int determinant_, (NDArray* input, NDArray* output), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template int determinant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
defaultContext = context;
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (input, output), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (input, output), FLOAT_TYPES);
} }
@ -193,7 +195,7 @@ template <typename T>
Nd4jLong n = input->sizeAt(-1); Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n; Nd4jLong n2 = n * n;
NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), input->getContext()); //, block.getWorkspace()); NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
for (int e = 0; e < output->lengthOf(); e++) { for (int e = 0; e < output->lengthOf(); e++) {
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) { for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) {
matrix.p(row, input->e<T>(k)); matrix.p(row, input->e<T>(k));
@ -220,11 +222,11 @@ template <typename T>
auto totalCount = output->lengthOf() / n2; auto totalCount = output->lengthOf() / n2;
output->assign(0.f); // fill up output tensor with zeros output->assign(0.f); // fill up output tensor with zeros
auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), input->getContext()); //, block.getWorkspace()); auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), input->getContext()); //, block.getWorkspace()); auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), input->getContext()); auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), input->getContext()); auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), input->getContext()); auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
for (int e = 0; e < totalCount; e++) { for (int e = 0; e < totalCount; e++) {
if (e) if (e)
@ -266,6 +268,7 @@ template <typename T>
} }
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
defaultContext = context;
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (input, output), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (input, output), FLOAT_TYPES);
} }
@ -308,8 +311,8 @@ template <typename T>
if (!inplace) if (!inplace)
output->assign(0.f); // fill up output tensor with zeros only inplace=false output->assign(0.f); // fill up output tensor with zeros only inplace=false
std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), input->getContext())); //, block.getWorkspace()); std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), defaultContext)); //, block.getWorkspace());
std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), input->getContext())); std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), defaultContext));
for (int e = 0; e < totalCount; e++) { for (int e = 0; e < totalCount; e++) {
@ -343,6 +346,7 @@ template <typename T>
} }
int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) { int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) {
defaultContext = context;
BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES);
} }

View File

@ -19,8 +19,6 @@
// //
#include <ops/declarable/helpers/sg_cb.h> #include <ops/declarable/helpers/sg_cb.h>
#include <AveragingArrayProxy.h>
#include <helpers/AveragingArrayProxy.h>
#include <specials.h> #include <specials.h>
#define HS_MAX_EXP 6.0f #define HS_MAX_EXP 6.0f

View File

@ -1496,8 +1496,10 @@ __global__ static void upsampling2dBPCuda(const void* vx, const Nd4jLong* xShape
const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank); const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
const Nd4jLong zCoord2 = coords[dimIH]; z[zOffset] = 0;
const Nd4jLong zCoord3 = coords[dimIH + 1];
const Nd4jLong zCoord2 = coords[dimIH] * factorH;
const Nd4jLong zCoord3 = coords[dimIH + 1] * factorW;
for(coords[dimIH] = zCoord2; coords[dimIH] < zCoord2 + factorH; ++coords[dimIH]) for(coords[dimIH] = zCoord2; coords[dimIH] < zCoord2 + factorH; ++coords[dimIH])
for(coords[dimIH + 1] = zCoord3; coords[dimIH + 1] < zCoord3 + factorW; ++coords[dimIH + 1]) for(coords[dimIH + 1] = zCoord3; coords[dimIH + 1] < zCoord3 + factorW; ++coords[dimIH + 1])
@ -1569,9 +1571,11 @@ __global__ static void upsampling3dBPCuda(const void* vx, const Nd4jLong* xShape
const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank); const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
const Nd4jLong zCoord2 = coords[dimID]; z[zOffset] = 0;
const Nd4jLong zCoord3 = coords[dimID + 1];
const Nd4jLong zCoord4 = coords[dimID + 2]; const Nd4jLong zCoord2 = coords[dimID] * factorD;
const Nd4jLong zCoord3 = coords[dimID + 1] * factorH;
const Nd4jLong zCoord4 = coords[dimID + 2] * factorW;
for(coords[dimID] = zCoord2; coords[dimID] < zCoord2 + factorD; ++coords[dimID]) for(coords[dimID] = zCoord2; coords[dimID] < zCoord2 + factorD; ++coords[dimID])
for(coords[dimID + 1] = zCoord3; coords[dimID + 1] < zCoord3 + factorH; ++coords[dimID + 1]) for(coords[dimID + 1] = zCoord3; coords[dimID + 1] < zCoord3 + factorH; ++coords[dimID + 1])

View File

@ -31,6 +31,7 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext();
// template <typename T> // template <typename T>
// static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) { // static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) {
@ -56,7 +57,8 @@ namespace helpers {
// BUILD_SINGLE_SELECTOR(matrix->dataType(), swapRows_, (matrix, theFirst, theSecond), FLOAT_TYPES); // BUILD_SINGLE_SELECTOR(matrix->dataType(), swapRows_, (matrix, theFirst, theSecond), FLOAT_TYPES);
// } // }
template<typename T> template<typename T>
static __global__ void invertKernelLow(void* invertedBuf, Nd4jLong* invertedShape, void* inputBuf, Nd4jLong* inputShape, Nd4jLong n) { static __global__ void
invertKernelLow(void *invertedBuf, Nd4jLong *invertedShape, void *inputBuf, Nd4jLong *inputShape, Nd4jLong n) {
T *inverted = reinterpret_cast<T *>(invertedBuf); T *inverted = reinterpret_cast<T *>(invertedBuf);
T *input = reinterpret_cast<T *>(inputBuf); T *input = reinterpret_cast<T *>(inputBuf);
@ -77,7 +79,8 @@ namespace helpers {
} }
template<typename T> template<typename T>
static __global__ void upvertKernel(void* invertedBuf, Nd4jLong* invertedShape, void* inputBuf, Nd4jLong* inputShape, Nd4jLong n) { static __global__ void
upvertKernel(void *invertedBuf, Nd4jLong *invertedShape, void *inputBuf, Nd4jLong *inputShape, Nd4jLong n) {
T *inverted = reinterpret_cast<T *>(invertedBuf); T *inverted = reinterpret_cast<T *>(invertedBuf);
T *input = reinterpret_cast<T *>(inputBuf); T *input = reinterpret_cast<T *>(inputBuf);
@ -94,9 +97,25 @@ namespace helpers {
} }
template<typename T> template<typename T>
static __global__ void upvertKernelUp(void* invertedBuf, Nd4jLong* invertedShape, void* inputBuf, Nd4jLong* inputShape, Nd4jLong n) { static __global__ void
T* inverted = reinterpret_cast<T*>(invertedBuf); upvertKernelUp(void *invertedBuf, Nd4jLong *invertedShape, void *inputBuf, Nd4jLong *inputShape, Nd4jLong n) {
T* input = reinterpret_cast<T*>(inputBuf);
__shared__ T* inverted;
__shared__ T* input;
__shared__ Nd4jLong* inputStride;
__shared__ Nd4jLong* invertedStride;
__shared__ Nd4jLong* invertedShapeOf;
__shared__ Nd4jLong* inputShapeOf;
if (threadIdx.x == 0) {
inverted = reinterpret_cast<T *>(invertedBuf);
input = reinterpret_cast<T *>(inputBuf);
inputStride = shape::stride(inputShape);
invertedStride = shape::stride(invertedShape);
invertedShapeOf = shape::shapeOf(invertedShape);
inputShapeOf = shape::shapeOf(inputShape);
}
__syncthreads();
auto start = threadIdx.x + blockIdx.x * blockDim.x; auto start = threadIdx.x + blockIdx.x * blockDim.x;
auto step = blockDim.x * gridDim.x; auto step = blockDim.x * gridDim.x;
@ -105,18 +124,19 @@ namespace helpers {
Nd4jLong pos[] = {i, i + 1}; Nd4jLong pos[] = {i, i + 1};
//Nd4jLong posY[] = {i, i}; //Nd4jLong posY[] = {i, i};
Nd4jLong posX[] = {i + 1, i + 1}; Nd4jLong posX[] = {i + 1, i + 1};
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2); auto xIndex = shape::getOffset(0, inputShapeOf, shape::stride(inputShape), pos, 2);
// auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posY, 2); // auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posY, 2);
// auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2); // auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
auto iIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posX, 2); auto iIndex = shape::getOffset(0, invertedShapeOf, invertedStride, posX, 2);
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2); auto zIndex = shape::getOffset(0, invertedShapeOf, invertedStride, pos, 2);
math::atomics::nd4j_atomicAdd(&inverted[zIndex], -input[xIndex] * inverted[iIndex]); // / input[yIndex]); math::atomics::nd4j_atomicAdd(&inverted[zIndex], -input[xIndex] * inverted[iIndex]); // / input[yIndex]);
//inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) / inputMatrix->t<T>(i, i) //inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) / inputMatrix->t<T>(i, i)
} }
} }
template<typename T> template<typename T>
static __global__ void invertLowKernel(void* invertedBuf, Nd4jLong* invertedShape, void* inputBuf, Nd4jLong* inputShape, Nd4jLong n) { static __global__ void
invertLowKernel(void *invertedBuf, Nd4jLong *invertedShape, void *inputBuf, Nd4jLong *inputShape, Nd4jLong n) {
T *inverted = reinterpret_cast<T *>(invertedBuf); T *inverted = reinterpret_cast<T *>(invertedBuf);
T *input = reinterpret_cast<T *>(inputBuf); T *input = reinterpret_cast<T *>(inputBuf);
@ -129,32 +149,50 @@ namespace helpers {
Nd4jLong posD[] = {i, i}; Nd4jLong posD[] = {i, i};
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2); auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2);
auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2); auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY,
2);
auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2); auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2);
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2); auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ,
2);
math::atomics::nd4j_atomicAdd(&inverted[zIndex], -inverted[yIndex] * input[xIndex] / input[dIndex]); math::atomics::nd4j_atomicAdd(&inverted[zIndex], -inverted[yIndex] * input[xIndex] / input[dIndex]);
} }
} }
} }
template<typename T> template<typename T>
static __global__ void invertUpKernel(void* invertedBuf, Nd4jLong* invertedShape, void* inputBuf, Nd4jLong* inputShape, Nd4jLong n) { static __global__ void
T* inverted = reinterpret_cast<T*>(invertedBuf);; invertUpKernel(void *invertedBuf, Nd4jLong *invertedShape, void *inputBuf, Nd4jLong *inputShape, Nd4jLong n) {
T* input = reinterpret_cast<T*>(inputBuf); __shared__ T* inverted;
__shared__ T* input;
__shared__ Nd4jLong* inputShapeOf;
__shared__ Nd4jLong* invertedShapeOf;
__shared__ Nd4jLong* invertedStrideOf;
__shared__ Nd4jLong* inputStrideOf;
for (int i = n - blockIdx.x - 2; i >= 0; i -= gridDim.x) { if (threadIdx.x == 0) {
for (int j = i + 2; j < n; j++) inverted = reinterpret_cast<T *>(invertedBuf);;
for (int k = i + threadIdx.x; k < n; k+= blockDim.x) { input = reinterpret_cast<T *>(inputBuf);
inputShapeOf = shape::shapeOf(inputShape);
invertedShapeOf = shape::shapeOf(invertedShape);
inputStrideOf = shape::stride(inputShape);
invertedStrideOf = shape::stride(invertedShape);
}
__syncthreads();
for (int i = (int)n - blockIdx.x - 2; i >= 0; i -= gridDim.x) {
for (int j = i + 2; j < (int)n; j++)
for (int k = i + threadIdx.x; k < (int)n; k += blockDim.x) {
Nd4jLong posZ[] = {i, j}; Nd4jLong posZ[] = {i, j};
Nd4jLong posY[] = {k, j}; Nd4jLong posY[] = {k, j};
Nd4jLong posX[] = {i, k}; Nd4jLong posX[] = {i, k};
// Nd4jLong posD[] = {i, i}; // Nd4jLong posD[] = {i, i};
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2); auto xIndex = shape::getOffset(0, inputShapeOf, inputStrideOf, posX, 2);
auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2); auto yIndex = shape::getOffset(0, invertedShapeOf, invertedStrideOf, posY, 2);
// auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2); // auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2);
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2); auto zIndex = shape::getOffset(0, invertedShapeOf, invertedStrideOf, posZ, 2);
math::atomics::nd4j_atomicAdd(&inverted[zIndex], -inverted[yIndex] * input[xIndex]);// / input[dIndex]); math::atomics::nd4j_atomicAdd(&inverted[zIndex], -inverted[yIndex] * input[xIndex]);// / input[dIndex]);
// printf("(%d, %d) inverted[%lld] = %lf (-inverted[%lld] * input[%lld]\n", blockIdx.x, threadIdx.x, zIndex, inverted[zIndex], yIndex, xIndex);
} }
} }
} }
@ -165,15 +203,18 @@ namespace helpers {
invertedMatrix->setIdentity(); invertedMatrix->setIdentity();
if (inputMatrix->isIdentityMatrix()) return; if (inputMatrix->isIdentityMatrix()) return;
LaunchContext* context = inputMatrix->getContext();
auto stream = context->getCudaStream(); auto stream = defaultContext->getCudaStream();
// invert main diagonal // invert main diagonal
upvertKernel<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); upvertKernel<T> << < 1, n, 512, *stream >> >
(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
// invert the second diagonal // invert the second diagonal
invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); invertKernelLow<T> << < 1, n, 512, *stream >> >
(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
// invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); // invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
invertLowKernel<T><<<n, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); invertLowKernel<T><<< n, n, 512, *stream >> >
(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
} }
void invertLowerMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) { void invertLowerMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) {
@ -186,19 +227,23 @@ namespace helpers {
static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) { static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) {
int n = inputMatrix->rows(); int n = inputMatrix->rows();
invertedMatrix->setIdentity(); invertedMatrix->setIdentity();
auto stream = inputMatrix->getContext()->getCudaStream(); auto stream = defaultContext->getCudaStream();
if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I
return; return;
} }
//upvertKernel<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); //upvertKernel<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
upvertKernelUp<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); upvertKernelUp<T><<<1, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),
invertUpKernel<T><<<n, n, 256, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
invertedMatrix->tickWriteDevice();
invertedMatrix->printIndexedBuffer("Step1 UP inversion");
invertUpKernel<T><<<n, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),
inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
} }
void invertUpperMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) { void invertUpperMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) {
NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE); BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE);
NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
} }
@ -244,46 +289,47 @@ namespace helpers {
// } // }
// } // }
template <typename T, typename F> // template <typename T, typename F>
template<typename T>
static __global__ void determinantKernel(T *compound, T *result, Nd4jLong len) { static __global__ void determinantKernel(T *compound, T *result, Nd4jLong len) {
F tempRes = (F)result[0]; //F tempRes = result[0];
auto start = blockIdx.x * blockDim.x + threadIdx.x; auto start = blockIdx.x * blockDim.x + threadIdx.x;
auto step = blockDim.x * gridDim.x; auto step = blockDim.x * gridDim.x;
for (auto i = start; i < len; i += step) { for (auto i = start; i < len; i += step) {
auto pos = i * len + i; //shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), di, 2); auto pos = i * len + i; //shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), di, 2);
math::atomics::nd4j_atomicMul<F>(&tempRes, (F)compound[pos]); math::atomics::nd4j_atomicMul(&result[0], compound[pos]);
}
__syncthreads();
if (threadIdx.x == 0) {
result[0] = (T)tempRes;
} }
} }
template <typename T, typename F> template<typename T>
static __global__ void determinantLogKernel(T *compound, T *result, Nd4jLong len) { static __global__ void determinantLogKernel(T *compound, T *result, Nd4jLong len) {
F tempRes = (F)result[0]; // F tempRes = (F)result[0];
auto start = blockIdx.x * blockDim.x + threadIdx.x; auto start = blockIdx.x * blockDim.x + threadIdx.x;
auto step = blockDim.x * gridDim.x; auto step = blockDim.x * gridDim.x;
for (auto i = start; i < len; i += step) { for (auto i = start; i < len; i += step) {
auto pos = i * len + i; //shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), di, 2); auto pos = i * len + i; //shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), di, 2);
math::atomics::nd4j_atomicMul<F>(&tempRes, (F)compound[pos]); math::atomics::nd4j_atomicAdd(result, math::nd4j_log<T,T>(math::nd4j_abs(compound[pos])));
}
__syncthreads();
if (threadIdx.x == 0) {
result[0] = (T)math::nd4j_log<F,F>(math::nd4j_abs(tempRes));
} }
// __syncthreads();
//
// if (threadIdx.x == 0) {
// result[0] = (T)math::nd4j_log<F,F>(math::nd4j_abs(tempRes));
// }
} }
template<typename T, typename F> template<typename T, typename F>
static __global__ void fillMatrix(void* output, Nd4jLong* outShape, void* input, Nd4jLong* inputShape, Nd4jLong pos, Nd4jLong rowLen) { static __global__ void
__shared__ F* matrix; fillMatrix(void *output, Nd4jLong *outShape, void *input, Nd4jLong *inputShape, Nd4jLong pos, Nd4jLong rowLen) {
__shared__ T* inputBuf; __shared__
__shared__ Nd4jLong inputLen; F *matrix;
__shared__ Nd4jLong n2; __shared__
T *inputBuf;
__shared__
Nd4jLong inputLen;
__shared__
Nd4jLong n2;
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
matrix = reinterpret_cast<F *>(output); matrix = reinterpret_cast<F *>(output);
@ -301,15 +347,17 @@ namespace helpers {
} }
} }
template <typename T, typename F> template<typename T>
static __global__ void returnMatrix(void* output, Nd4jLong* outputShape, void* input, Nd4jLong* inputShape, Nd4jLong pos, Nd4jLong rowLen) { static __global__ void
__shared__ F* matrix; returnMatrix(void *output, Nd4jLong *outputShape, void *input, Nd4jLong *inputShape, Nd4jLong pos,
Nd4jLong rowLen) {
__shared__ T *matrix;
__shared__ T *outputBuf; __shared__ T *outputBuf;
__shared__ Nd4jLong outputLen; __shared__ Nd4jLong outputLen;
__shared__ Nd4jLong n2; __shared__ Nd4jLong n2;
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
matrix = reinterpret_cast<F*>(input); matrix = reinterpret_cast<T *>(input);
outputBuf = reinterpret_cast<T *>(output); outputBuf = reinterpret_cast<T *>(output);
outputLen = shape::length(inputShape); outputLen = shape::length(inputShape);
n2 = rowLen * rowLen; n2 = rowLen * rowLen;
@ -344,6 +392,7 @@ namespace helpers {
auto n = input->rows(); auto n = input->rows();
cusolverDnHandle_t cusolverH = nullptr; cusolverDnHandle_t cusolverH = nullptr;
cusolverStatus_t status = cusolverDnCreate(&cusolverH); cusolverStatus_t status = cusolverDnCreate(&cusolverH);
defaultContext = context;
if (CUSOLVER_STATUS_SUCCESS != status) { if (CUSOLVER_STATUS_SUCCESS != status) {
throw cuda_exception::build("Cannot create cuSolver handle", status); throw cuda_exception::build("Cannot create cuSolver handle", status);
} }
@ -366,7 +415,8 @@ namespace helpers {
double *d_work = nullptr; double *d_work = nullptr;
err = cudaMalloc((void **) &d_work, sizeof(float) * lwork); err = cudaMalloc((void **) &d_work, sizeof(float) * lwork);
if (err) { if (err) {
throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver data buffer", err); throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver data buffer",
err);
} }
double *matrix = reinterpret_cast<double *>(input->specialBuffer()); double *matrix = reinterpret_cast<double *>(input->specialBuffer());
status = cusolverDnDgetrf_bufferSize( status = cusolverDnDgetrf_bufferSize(
@ -401,12 +451,14 @@ namespace helpers {
d_work, d_work,
permutationBuf, permutationBuf,
d_info); d_info);
fillUpPermutation<double><<<n, n, 128, *stream>>>(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); fillUpPermutation<double> << < n, n, 1024, *stream >> >
(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n);
permutation->tickWriteDevice(); permutation->tickWriteDevice();
} }
err = cudaFree(d_work); err = cudaFree(d_work);
if (err) { if (err) {
throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver data buffer", err); throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver data buffer",
err);
} }
} }
break; break;
@ -415,7 +467,8 @@ namespace helpers {
float *d_work = nullptr; float *d_work = nullptr;
err = cudaMalloc((void **) &d_work, sizeof(float) * lwork); err = cudaMalloc((void **) &d_work, sizeof(float) * lwork);
if (err) { if (err) {
throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver data buffer", err); throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver data buffer",
err);
} }
status = cusolverDnSgetrf_bufferSize( status = cusolverDnSgetrf_bufferSize(
@ -451,12 +504,14 @@ namespace helpers {
d_work, d_work,
permutationBuf, permutationBuf,
d_info); d_info);
fillUpPermutation<T><<<n, n, 128, *stream>>>(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); fillUpPermutation<T> <<< n, n, 128, *stream >> >
(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n);
permutation->tickWriteDevice(); permutation->tickWriteDevice();
} }
err = cudaFree(d_work); err = cudaFree(d_work);
if (err) { if (err) {
throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver data buffer", err); throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver data buffer",
err);
} }
} }
@ -472,20 +527,25 @@ namespace helpers {
// NDArray::registerSpecialUse({input}, {input}); // NDArray::registerSpecialUse({input}, {input});
input->tickWriteDevice(); input->tickWriteDevice();
} }
BUILD_SINGLE_TEMPLATE(template void lup_, (LaunchContext* context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_NATIVE);
BUILD_SINGLE_TEMPLATE(template void lup_,
(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation),
FLOAT_NATIVE);
template<typename T> template<typename T>
static int determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { static int determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
Nd4jLong n = input->sizeAt(-1); Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n; Nd4jLong n2 = n * n;
std::vector<int> dims(); std::vector<int> dims();
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
{input->rankOf() - 2, input->rankOf() - 1});
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1}); //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
DataType dtype = input->dataType(); // DataType dtype = input->dataType();
if (dtype != DataType::DOUBLE) // if (dtype != DataType::DOUBLE)
dtype = DataType::FLOAT32; // dtype = DataType::FLOAT32;
defaultContext = context;
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, input->getContext()); //, block.getWorkspace()); auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(),
defaultContext); //, block.getWorkspace());
auto det = NDArrayFactory::create<T>(1); auto det = NDArrayFactory::create<T>(1);
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
@ -494,7 +554,8 @@ namespace helpers {
for (int e = 0; e < output->lengthOf(); e++) { for (int e = 0; e < output->lengthOf(); e++) {
Nd4jLong pos = e * n2; Nd4jLong pos = e * n2;
// if (matrix.dataType() == input->dataType()) // if (matrix.dataType() == input->dataType())
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); fillMatrix<T, T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
// else // else
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); // fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
@ -506,7 +567,8 @@ namespace helpers {
auto inputBuf = reinterpret_cast<T *>(matrix.specialBuffer()); auto inputBuf = reinterpret_cast<T *>(matrix.specialBuffer());
auto outputBuf = reinterpret_cast<T *>(output->specialBuffer()) + offset; auto outputBuf = reinterpret_cast<T *>(output->specialBuffer()) + offset;
// if (matrix.dataType() == input->dataType()) // if (matrix.dataType() == input->dataType())
determinantKernel<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n); determinantKernel<T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
(inputBuf, outputBuf, n);
// else // else
// determinantKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n); // determinantKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
} }
@ -516,6 +578,7 @@ namespace helpers {
} }
int determinant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { int determinant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE); BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE);
NDArray::registerSpecialUse({output}, {input}); NDArray::registerSpecialUse({output}, {input});
@ -523,26 +586,29 @@ namespace helpers {
template<typename T> template<typename T>
int logAbsDeterminant_(LaunchContext *context, NDArray *input, NDArray *output) { int logAbsDeterminant_(LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
Nd4jLong n = input->sizeAt(-1); Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n; Nd4jLong n2 = n * n;
std::vector<int> dims(); std::vector<int> dims();
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
{input->rankOf() - 2, input->rankOf() - 1});
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1}); //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
DataType dtype = input->dataType(); DataType dtype = input->dataType();
if (dtype != DataType::DOUBLE) if (dtype != DataType::DOUBLE)
dtype = DataType::FLOAT32; dtype = DataType::FLOAT32;
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, input->getContext()); //, block.getWorkspace()); auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype,
defaultContext); //, block.getWorkspace());
auto det = NDArrayFactory::create<T>(1); auto det = NDArrayFactory::create<T>(1);
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
dim3 launchDims(256, 256, 1024); dim3 launchDims(256, 256, 1024);
output->assign(1.f); output->assign(0.f);
for (int e = 0; e < output->lengthOf(); e++) { for (int e = 0; e < output->lengthOf(); e++) {
Nd4jLong pos = e * n2; Nd4jLong pos = e * n2;
// if (matrix.dataType() == input->dataType()) // if (matrix.dataType() == input->dataType())
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); fillMatrix<T, T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
// else // else
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); // fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
@ -554,7 +620,8 @@ namespace helpers {
auto inputBuf = reinterpret_cast<T *>(matrix.specialBuffer()); auto inputBuf = reinterpret_cast<T *>(matrix.specialBuffer());
auto outputBuf = reinterpret_cast<T *>(output->specialBuffer()) + offset; auto outputBuf = reinterpret_cast<T *>(output->specialBuffer()) + offset;
// if (matrix.dataType() == input->dataType()) // if (matrix.dataType() == input->dataType())
determinantLogKernel<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n); determinantLogKernel<T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
(inputBuf, outputBuf, n);
// else // else
// determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n); // determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
} }
@ -566,23 +633,35 @@ namespace helpers {
} }
int logAbsDeterminant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { int logAbsDeterminant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE); BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE);
NDArray::registerSpecialUse({output}, {input}); NDArray::registerSpecialUse({output}, {input});
} }
template<typename T> template<typename T>
static __global__ void fillLowerUpperKernel(void* lowerBuf, Nd4jLong* lowerShape, void* upperBuf, Nd4jLong* upperShape, void* matrixBuf, Nd4jLong* matrixShape, Nd4jLong n) { static __global__ void
fillLowerUpperKernel(void *lowerBuf, Nd4jLong *lowerShape, void *upperBuf, Nd4jLong *upperShape,
void *matrixBuf, Nd4jLong *matrixShape, Nd4jLong n) {
__shared__ Nd4jLong* xShapeOf; __shared__
__shared__ Nd4jLong* yShapeOf; Nd4jLong *xShapeOf;
__shared__ Nd4jLong* zShapeOf; __shared__
__shared__ Nd4jLong* xStrideOf; Nd4jLong *yShapeOf;
__shared__ Nd4jLong* yStrideOf; __shared__
__shared__ Nd4jLong* zStrideOf; Nd4jLong *zShapeOf;
__shared__ T* lowerMatrix; __shared__
__shared__ T* upperMatrix; Nd4jLong *xStrideOf;
__shared__ T* matrix; __shared__
Nd4jLong *yStrideOf;
__shared__
Nd4jLong *zStrideOf;
__shared__
T *lowerMatrix;
__shared__
T *upperMatrix;
__shared__
T *matrix;
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
xShapeOf = shape::shapeOf(lowerShape); xShapeOf = shape::shapeOf(lowerShape);
@ -617,38 +696,56 @@ namespace helpers {
template<typename T> template<typename T>
static int inverse_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { static int inverse_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
auto n = input->sizeAt(-1); auto n = input->sizeAt(-1);
auto n2 = n * n; auto n2 = n * n;
auto dtype = input->dataType(); auto dtype = DataTypeUtils::fromT<T>(); //input->dataType();
if (dtype != DataType::DOUBLE) // if (dtype != DataType::DOUBLE)
dtype = DataType::FLOAT32; // dtype = DataType::FLOAT32;
NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, input->getContext()); NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, input->getContext()); NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, input->getContext()); NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, input->getContext()); NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, input->getContext()); NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {output->rankOf() - 2, output->rankOf() - 1}); {input->rankOf() - 2,
input->rankOf() - 1});
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(),
{output->rankOf() - 2,
output->rankOf() - 1});
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
for (auto i = 0LL; i < packX.numberOfTads(); i++) { for (auto i = 0LL; i < packX.numberOfTads(); i++) {
fillMatrix<T, T><<<1, n2, 1024, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n); fillMatrix<T, T> << < 1, n2, 1024, *stream >> >
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(),
i * n2, n);
matrix.tickWriteDevice(); matrix.tickWriteDevice();
compound.assign(matrix); compound.assign(matrix);
lup_<T>(context, &compound, nullptr, nullptr); lup_<T>(context, &compound, nullptr, nullptr);
fillLowerUpperKernel<T><<<n, n, 1024, *stream>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n); fillLowerUpperKernel<T> << < n, n, 1024, *stream >> >
(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n);
matrix.assign(0); matrix.assign(0);
invertUpperMatrix(&upper, &matrix); // U^{-1} invertUpperMatrix(&upper, &matrix); // U^{-1}
matrix.tickWriteDevice();
// matrix.printIndexedBuffer("Upper Inverted");
compound.assign(0); compound.assign(0);
invertLowerMatrix(&lower, &compound); // L{-1} invertLowerMatrix(&lower, &compound); // L{-1}
compound.tickWriteDevice();
// compound.printIndexedBuffer("Lower Inverted");
// matrix.tickWriteDevice();
// compound.tickWriteDevice();
nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0); nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0);
returnMatrix<T, T><<<1, n2, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n); upper.tickWriteDevice();
// upper.printIndexedBuffer("Full inverted");
returnMatrix<T> << < 1, n2, 1024, *stream >> >
(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(),
i * n2, n);
} }
return Status::OK(); return Status::OK();
} }
int inverse(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { int inverse(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE); BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE);
NDArray::registerSpecialUse({output}, {input}); NDArray::registerSpecialUse({output}, {input});
@ -669,7 +766,8 @@ namespace helpers {
} }
template<typename F> template<typename F>
__global__ void adjustResultsKernel(F* dArray, Nd4jLong* shape, Nd4jLong* offsets, Nd4jLong batchSize, Nd4jLong n) { __global__ void
adjustResultsKernel(F *dArray, Nd4jLong *shape, Nd4jLong *offsets, Nd4jLong batchSize, Nd4jLong n) {
//auto i = blockIdx.x * blockDim.x + threadIdx.x; //auto i = blockIdx.x * blockDim.x + threadIdx.x;
Nd4jLong *shapeOf = shape::shapeOf(shape); Nd4jLong *shapeOf = shape::shapeOf(shape);
Nd4jLong *strideOf = shape::stride(shape); Nd4jLong *strideOf = shape::stride(shape);
@ -690,6 +788,7 @@ namespace helpers {
int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
if (!inplace) if (!inplace)
output->assign(input); output->assign(input);
defaultContext = context;
std::unique_ptr<NDArray> tempOutput(output->dup()); std::unique_ptr<NDArray> tempOutput(output->dup());
cusolverDnHandle_t handle = nullptr; cusolverDnHandle_t handle = nullptr;
auto n = input->sizeAt(-1); auto n = input->sizeAt(-1);
@ -700,19 +799,23 @@ namespace helpers {
throw cuda_exception::build("helpers::cholesky_: Cannot create solver handle", status); throw cuda_exception::build("helpers::cholesky_: Cannot create solver handle", status);
} }
F **dArrayBatch = nullptr; F **dArrayBatch = nullptr;
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput->getShapeInfo(), {tempOutput->rankOf() - 2, tempOutput->rankOf() - 1}); auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput->getShapeInfo(),
{tempOutput->rankOf() - 2,
tempOutput->rankOf() - 1});
const Nd4jLong batchSize = packX.numberOfTads(); const Nd4jLong batchSize = packX.numberOfTads();
int *dInfoArray = nullptr; int *dInfoArray = nullptr;
auto err = cudaMalloc((void **) &dArrayBatch, sizeof(F *) * batchSize); auto err = cudaMalloc((void **) &dArrayBatch, sizeof(F *) * batchSize);
if (err) { if (err) {
throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver batch data buffer", err); throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver batch data buffer",
err);
} }
err = cudaMalloc((void **) &dInfoArray, sizeof(int) * batchSize); err = cudaMalloc((void **) &dInfoArray, sizeof(int) * batchSize);
if (err) { if (err) {
throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver errors buffer", err); throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver errors buffer", err);
} }
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
fillBatchKernel<F><<<1, batchSize, 128, *stream>>>(dArrayBatch, reinterpret_cast<F*>(tempOutput->specialBuffer()), packX.specialOffsets(), batchSize); fillBatchKernel<F> << < 1, batchSize, 128, *stream >> >
(dArrayBatch, reinterpret_cast<F *>(tempOutput->specialBuffer()), packX.specialOffsets(), batchSize);
status = cusolverDnSetStream(handle, *stream); status = cusolverDnSetStream(handle, *stream);
if (CUSOLVER_STATUS_SUCCESS != status) { if (CUSOLVER_STATUS_SUCCESS != status) {
@ -741,11 +844,13 @@ namespace helpers {
if (CUSOLVER_STATUS_SUCCESS != status) { if (CUSOLVER_STATUS_SUCCESS != status) {
throw cuda_exception::build("helpers::cholesky_: Cholesky factorization failed for batch", status); throw cuda_exception::build("helpers::cholesky_: Cholesky factorization failed for batch", status);
} }
adjustResultsKernel<F><<<batchSize, n2, 128, *stream>>>(reinterpret_cast<F*>(tempOutput->specialBuffer()), packX.specialShapeInfo(), packX.specialOffsets(), batchSize, n); adjustResultsKernel<F> << < batchSize, n2, 128, *stream >> >
(reinterpret_cast<F *>(tempOutput->specialBuffer()), packX.specialShapeInfo(), packX.specialOffsets(), batchSize, n);
err = cudaFree(dArrayBatch); err = cudaFree(dArrayBatch);
if (err) { if (err) {
throw cuda_exception::build("helpers::cholesky_: Cannot deallocate memory for solver batch data buffer", err); throw cuda_exception::build("helpers::cholesky_: Cannot deallocate memory for solver batch data buffer",
err);
} }
err = cudaFree(dInfoArray); err = cudaFree(dInfoArray);
if (err) { if (err) {
@ -763,13 +868,16 @@ namespace helpers {
// template <typename T> // template <typename T>
int cholesky_(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { int cholesky_(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
defaultContext = context;
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
if (input->dataType() == DataType::DOUBLE) if (input->dataType() == DataType::DOUBLE)
cholesky__<double>(context, input, output, inplace); cholesky__<double>(context, input, output, inplace);
else if (input->dataType() == DataType::FLOAT32) else if (input->dataType() == DataType::FLOAT32)
cholesky__<float>(context, input, output, inplace); cholesky__<float>(context, input, output, inplace);
else { else {
std::unique_ptr<NDArray> tempOutput(NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, input->getContext())); std::unique_ptr<NDArray> tempOutput(
NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32,
defaultContext));
tempOutput->assign(input); tempOutput->assign(input);
cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true); cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true);
output->assign(tempOutput.get()); output->assign(tempOutput.get());
@ -780,12 +888,17 @@ namespace helpers {
int cholesky(nd4j::LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { int cholesky(nd4j::LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
// BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES); // BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES);
defaultContext = context;
return cholesky_(context, input, output, inplace); return cholesky_(context, input, output, inplace);
} }
// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES); // BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
BUILD_SINGLE_TEMPLATE(template int inverse_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_NATIVE); BUILD_SINGLE_TEMPLATE(template int inverse_, (nd4j::LaunchContext * context, NDArray * input, NDArray * output),
FLOAT_NATIVE);
__global__ void logDetKernel(double* inputBuf, Nd4jLong* inputShape, Nd4jLong batchNum, Nd4jLong* tadShape, Nd4jLong* tadOffsets, double* outputBuf, Nd4jLong* outputShape) { template<typename T>
__global__ void
logDetKernel(T *inputBuf, Nd4jLong *inputShape, Nd4jLong batchNum, Nd4jLong *tadShape, Nd4jLong *tadOffsets,
T *outputBuf, Nd4jLong *outputShape) {
__shared__ int n; __shared__ int n;
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
@ -793,25 +906,28 @@ namespace helpers {
} }
__syncthreads(); __syncthreads();
double* output = outputBuf; T *output = outputBuf;
double* input = inputBuf; T *input = inputBuf;
Nd4jLong *shapeOf = shape::shapeOf(tadShape); Nd4jLong *shapeOf = shape::shapeOf(tadShape);
Nd4jLong *strideOf = shape::stride(tadShape); Nd4jLong *strideOf = shape::stride(tadShape);
for (auto i = blockIdx.x; i < batchNum; i += gridDim.x) { for (auto i = blockIdx.x; i < batchNum; i += gridDim.x) {
double* current = input + tadOffsets[i]; T *current = input + tadOffsets[i];
auto zIndex = shape::getIndexOffset(i, outputShape, batchNum); auto zIndex = shape::getIndexOffset(i, outputShape, batchNum);
for (auto e = threadIdx.x; e < n; e += blockDim.x) { for (auto e = threadIdx.x; e < n; e += blockDim.x) {
Nd4jLong diag[] = {e, e}; Nd4jLong diag[] = {e, e};
auto xIndex = shape::getOffset(0, shapeOf, strideOf, diag, 2); auto xIndex = shape::getOffset(0, shapeOf, strideOf, diag, 2);
math::atomics::nd4j_atomicAdd(&output[zIndex], math::nd4j_log<double,double>(current[xIndex] * current[xIndex])); math::atomics::nd4j_atomicAdd(&output[zIndex],
math::nd4j_log<T, T>(current[xIndex] * current[xIndex]));
} }
} }
} }
int logdetFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* output) { template<typename T>
int logdetFunctor_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
auto n2 = input->sizeAt(-1) * input->sizeAt(-2); auto n2 = input->sizeAt(-1) * input->sizeAt(-2);
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
@ -825,17 +941,28 @@ namespace helpers {
cholesky(context, input, tempOutput.get(), false); cholesky(context, input, tempOutput.get(), false);
tempOutput->syncToHost(); tempOutput->syncToHost();
tempOutput->printIndexedBuffer("Cholesky res!!!"); tempOutput->printIndexedBuffer("Cholesky res!!!");
auto outputBuf = reinterpret_cast<double*>(output->specialBuffer()); // + e * n2; // + e * n2; auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()); // + e * n2; // + e * n2;
auto inputBuf = reinterpret_cast<double*>(tempOutput->specialBuffer()); auto inputBuf = reinterpret_cast<T*>(tempOutput->specialBuffer());
output->assign(0); output->assign(0);
output->syncToDevice(); output->syncToDevice();
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput->getShapeInfo(),
logDetKernel<<<packX.numberOfTads(), n2, 128, *stream>>>(inputBuf, tempOutput->specialShapeInfo(), packX.numberOfTads(), packX.specialShapeInfo(), packX.specialOffsets(), outputBuf, output->specialShapeInfo()); {input->rankOf() - 2,
input->rankOf() - 1});
logDetKernel<T> << < packX.numberOfTads(), n2, 128, *stream >> >
(inputBuf, tempOutput->specialShapeInfo(), packX.numberOfTads(), packX.specialShapeInfo(), packX.specialOffsets(), outputBuf, output->specialShapeInfo());
// } // }
NDArray::registerSpecialUse({output}, {input}); NDArray::registerSpecialUse({output}, {input});
//delete tempOutput; //delete tempOutput;
return Status::OK(); return Status::OK();
} }
int logdetFunctor(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
BUILD_SINGLE_SELECTOR(output->dataType(), logdetFunctor_, (context, input, output), FLOAT_NATIVE);
}
BUILD_SINGLE_TEMPLATE(template int logdetFunctor_,
(nd4j::LaunchContext * context, NDArray * input, NDArray * output), FLOAT_NATIVE);
} }
} }
} }

View File

@ -66,7 +66,7 @@ void zeta(nd4j::LaunchContext * context, const NDArray& x, const NDArray& q, NDA
if(!x.isActualOnDeviceSide()) x.syncToDevice(); if(!x.isActualOnDeviceSide()) x.syncToDevice();
if(!q.isActualOnDeviceSide()) q.syncToDevice(); if(!q.isActualOnDeviceSide()) q.syncToDevice();
int threadsPerBlock = MAX_NUM_THREADS; int threadsPerBlock = MAX_NUM_THREADS / 2;
int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
BUILD_SINGLE_SELECTOR(x.dataType(), zetaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), q.getSpecialBuffer(), q.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(x.dataType(), zetaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), q.getSpecialBuffer(), q.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES);

View File

@ -1,63 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "testlayers.h"
#include <ops/declarable/CustomOperations.h>
#include <AveragingArrayProxy.h>
using namespace nd4j;
using namespace nd4j::ops;
using namespace nd4j::graph;
class AveragingArrayTests : public testing::Test {
public:
};
TEST_F(AveragingArrayTests, test_basic_reads_1) {
auto exp0 = NDArrayFactory::create<double>('c', {1, 5},{3.0, 3.0, 3.0, 3.0, 3.0});
auto original = NDArrayFactory::create<double>('c', {100, 5});
original.assign(1.0);
AveragingArrayProxy proxy(&original);
auto writeable0 = proxy.writeable(1, 0);
auto writeable1 = proxy.writeable(1, 1);
auto writeable2 = proxy.writeable(2, 1);
ASSERT_FALSE(writeable0 == nullptr);
ASSERT_FALSE(writeable1 == nullptr);
ASSERT_FALSE(writeable2 == nullptr);
writeable0->assign(2.0);
writeable1->assign(4.0);
writeable2->assign(3.0);
auto r = proxy.collapseWrites();
ASSERT_TRUE(r);
auto row1 = original({1,2, 0,0}, true);
auto row2 = original({2,3, 0,0}, true);
ASSERT_EQ(exp0, row1);
ASSERT_EQ(exp0, row2);
}

View File

@ -2211,55 +2211,6 @@ TEST_F(ConvolutionTests1, upsampling3d_test2) {
delete results; delete results;
} }
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, upsampling2d_bp_test1) {
const int bS=1, iH=2,iW=2, iC=1;
const int factorH=2, factorW=2;
const int isNCHW = 1; // data format, default is NCHW
auto input = NDArrayFactory::create<float>('c', {bS, iC, iH, iW});
auto gradO = NDArrayFactory::create<float>('c', {bS, iC, iH*factorH, iW*factorW});
gradO = 1.;
auto expGradI = NDArrayFactory::create<float>('c', {bS, iC, iH, iW});
expGradI = 4.;
nd4j::ops::upsampling2d_bp op;
auto results = op.execute({&input, &gradO}, {}, {isNCHW});
auto* gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expGradI.isSameShape(gradI));
ASSERT_TRUE(expGradI.equalsTo(gradI));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, upsampling2d_bp_test2) {
const int bS=1, iH=2,iW=2, iC=1;
const int factorH=2, factorW=2;
const int isNCHW = 0; // data format, default is NCHW
auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
auto gradO = NDArrayFactory::create<float>('c', {bS, iH*factorH, iW*factorW, iC});
gradO = 1.;
auto expGradI = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
expGradI = 4.;
nd4j::ops::upsampling2d_bp op;
auto results = op.execute({&input, &gradO}, {}, {isNCHW});
auto* gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expGradI.isSameShape(gradI));
ASSERT_TRUE(expGradI.equalsTo(gradI));
delete results;
}
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, upsampling3d_bp_test1) { TEST_F(ConvolutionTests1, upsampling3d_bp_test1) {
@ -2315,6 +2266,70 @@ TYPED_TEST(TypedConvolutionTests1, conv2D_input_BP_test1) {
delete results; delete results;
} }
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, upsampling3d_bp_test3) {
const int bS=1, iD=3,iH=3,iW=3, iC=2;
const int factorD=2, factorH=2, factorW=2;
const int isNCDHW = 1; // data format, default is NCHW
NDArray input('c', {bS, iC, iD, iH, iW}, nd4j::DataType::FLOAT32);
NDArray gradO('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, {0.6793504, 0.35508695, 0.84278935, 0.20031333, 0.7014987, 0.31069338,
0.44793984, 0.93800974, 0.32667395, 0.15187258, 0.38331753, 0.78212297, 0.1988072, 0.7985636, 0.1632634, 0.14696825, 0.26089668,
0.13505761, 0.7562093, 0.27545404, 0.36908787, 0.09282647, 0.83649176, 0.26841334, 0.09506222, 0.31279507, 0.13591796, 0.5175439,
0.32870287, 0.061735712, 0.39643127, 0.248016, 0.5489592, 0.115046196, 0.8143622, 0.7215636, 0.40449402, 0.29908907, 0.4038839,
0.9883108, 0.022296403, 0.927782, 0.3184157, 0.0685462, 0.28453344, 0.23272, 0.35214192, 0.058909304, 0.7112212, 0.6744568, 0.19694561,
0.6994972, 0.0743224, 0.42042503, 0.5842631, 0.14957358, 0.44640633, 0.72307247, 0.06448108, 0.48307765, 0.8759956, 0.5698191, 0.4458631,
0.5277549, 0.016646361, 0.753678, 0.14063567, 0.7541292, 0.16193217, 0.7750374, 0.3326449, 0.11739397, 0.017710684, 0.60847557, 0.52515227,
0.9171938, 0.84989065, 0.5894228, 0.85227835, 0.39063585, 0.88968325, 0.6694452, 0.698873, 0.96147966, 0.15740126, 0.15736352, 0.49352047,
0.5699365, 0.12683152, 0.11572781, 0.7863682, 0.737939, 0.49007934, 0.6084143, 0.9564999, 0.3900982, 0.14730452, 0.8506447, 0.49765033,
0.07186628, 0.08214969, 0.035314173, 0.7320408, 0.36993408, 0.8406658, 0.27389422, 0.43179566, 0.13323106, 0.19297548, 0.24689731, 0.38641843,
0.51154125, 0.19903564, 0.1416313, 0.69769853, 0.25363067, 0.78221816, 0.9300991, 0.3355119, 0.5588076, 0.6643576, 0.018850708, 0.63755876,
0.2904297, 0.43490165, 0.84251267, 0.46609768, 0.38139546, 0.52318525, 0.9901826, 0.9257676, 0.6434591, 0.016828254, 0.9187561, 0.22897908,
0.0063138064, 0.66597503, 0.19036093, 0.59552056, 0.69888055, 0.22146936, 0.9124342, 0.8708221, 0.7273687, 0.52397245, 0.66288394, 0.2188415,
0.3354802, 0.03566524, 0.5101009, 0.5017283, 0.75122046, 0.1884508, 0.7407126, 0.6253045, 0.47145858, 0.5369367, 0.19884548, 0.99008304,
0.08256686, 0.91884845, 0.02360027, 0.98895234, 0.3751719, 0.91783875, 0.4338776, 0.6783008, 0.6667967, 0.46720362, 0.7508773, 0.52304846,
0.76631916, 0.4187526, 0.7653719, 0.5159193, 0.42730415, 0.49462363, 0.2731735, 0.8862948, 0.043214794, 0.3197591, 0.040378205, 0.5427239,
0.9228089, 0.045940384, 0.70047987, 0.8419288, 0.53966296, 0.009444186, 0.038044546, 0.03158029, 0.43485752, 0.9204235, 0.5478789, 0.8290083,
0.11868837, 0.0229866, 0.6639305, 0.8757367, 0.8279557, 0.76270294, 0.43242732, 0.4713431, 0.2569212, 0.30575937, 0.44395888, 0.99384075,
0.6127142, 0.44844577, 0.6347944, 0.098358564, 0.34233716, 0.9329664, 0.65776783, 0.108565055, 0.2052629, 0.46441218, 0.041791342, 0.89369565,
0.7000381, 0.2106213, 0.51152664, 0.44200692, 0.8293282, 0.20901772, 0.6387249, 0.8016979, 0.11178707, 0.109545894, 0.19654618, 0.060582615,
0.08239174, 0.64630795, 0.32862368, 0.60225064, 0.8328141, 0.5484566, 0.8120276, 0.38822946, 0.6742381, 0.34913155, 0.42887798, 0.45344824,
0.73956585, 0.9714739, 0.42937812, 0.45185348, 0.84535813, 0.046436775, 0.8802151, 0.8676222, 0.42625394, 0.4985318, 0.42399272, 0.122144565,
0.0060101906, 0.47253844, 0.18123977, 0.86316174, 0.5863874, 0.3852012, 0.9785553, 0.0054711984, 0.88500834, 0.020897374, 0.27467912, 0.3852802,
0.0766939, 0.94622654, 0.38687763, 0.3308602, 0.7770494, 0.9052543, 0.22258204, 0.42207044, 0.18050623, 0.21057767, 0.012561422, 0.7977821,
0.61251044, 0.7203693, 0.6028265, 0.6036933, 0.1446382, 0.6712341, 0.76634467, 0.4854034, 0.26634562, 0.76523924, 0.16348523, 0.2663676,
0.96846986, 0.8273284, 0.10700377, 0.7600526, 0.6771002, 0.47963092, 0.21264452, 0.56934077, 0.5514792, 0.85725874, 0.99090636, 0.54562527,
0.93597686, 0.21142527, 0.4628326, 0.35011524, 0.31464386, 0.31164807, 0.65928996, 0.94418925, 0.39666295, 0.9496393, 0.103756346, 0.482158,
0.49171793, 0.4108867, 0.22594318, 0.97093135, 0.5974685, 0.34632966, 0.54835194, 0.10499302, 0.9767778, 0.55008715, 0.54379046, 0.3583731,
0.33369112, 0.04279039, 0.24939054, 0.23943715, 0.06775989, 0.7750291, 0.24329625, 0.4327169, 0.86916673, 0.80322117, 0.049972698, 0.47177452,
0.37419558, 0.15303156, 0.121425234, 0.75884604, 0.8191354, 0.48554084, 0.053899214, 0.7858246, 0.39219773, 0.77579063, 0.34507045, 0.46070176,
0.14496958, 0.47706795, 0.50678796, 0.64902323, 0.3277943, 0.0017530271, 0.6536156, 0.8582253, 0.95703506, 0.9963951, 0.8239163, 0.305142,
0.012419582, 0.9498972, 0.1595827, 0.47947606, 0.5071124, 0.78227425, 0.2066719, 0.5217094, 0.7841406, 0.5260441, 0.49798164, 0.10975622,
0.8633349, 0.76298475, 0.14295428, 0.6131504, 0.43794408, 0.50339264, 0.4504877, 0.19235311, 0.6678411, 0.80769485, 0.67495126, 0.96461457,
0.10535406, 0.66438645, 0.4372345, 0.93851465, 0.8635335, 0.3405871, 0.45652762, 0.3636232, 0.52931345, 0.20154329, 0.07698499, 0.6125804,
0.3583082, 0.3894796, 0.32601944, 0.5237369, 0.66683626, 0.08541841, 0.4815708, 0.11897489, 0.97555137, 0.3602705, 0.9620871, 0.6361821,
0.71167386, 0.5134439, 0.57761437, 0.58598644, 0.39387667, 0.6966405, 0.46841687, 0.85788506, 0.9957087, 0.051309288, 0.24846801, 0.55938333,
0.10230542, 0.9370694, 0.57527155, 0.54656035, 0.28896323, 0.51303476, 0.8865, 0.38641605, 0.9836358}, nd4j::DataType::FLOAT32);
NDArray expGradI('c', {bS, iC, iD, iH, iW}, {3.510932, 3.4310975, 3.538762, 4.148549, 2.8380678, 2.5431657, 3.3928843, 3.228055, 3.1467278,
3.2603023, 5.611751, 4.334653, 3.3697734, 4.603307, 4.4357986, 4.32991, 3.0532732, 3.1370173, 4.181534, 2.9965065, 2.8553872, 5.2719016,
4.5671935, 3.7027276, 3.3517184, 5.2544537, 3.5107024, 4.1496124, 3.9333878, 3.1798909, 3.1446428, 3.0932689, 3.9730802, 3.0466917,
4.9675374, 4.769673, 3.766952, 3.6375027, 3.6492167, 4.9440994, 3.8379507, 3.467589, 4.719474, 3.1295977, 4.5177174, 4.2760015, 2.8443856,
4.225355, 4.377341, 4.4398847, 4.710785, 4.4199953, 3.928307, 4.8769503}, nd4j::DataType::FLOAT32);
nd4j::ops::upsampling3d_bp op;
auto results = op.execute({&input, &gradO}, {}, {isNCDHW});
auto* gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expGradI.isSameShape(gradI));
ASSERT_TRUE(expGradI.equalsTo(gradI));
delete results;
}
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, deconv2d_test1) { TEST_F(ConvolutionTests1, deconv2d_test1) {

View File

@ -1967,6 +1967,84 @@ TYPED_TEST(TypedConvolutionTests2, pnormpool2d_bp_3) {
} }
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests2, upsampling2d_bp_1) {
const int bS=1, iH=2,iW=2, iC=1;
const int factorH=2, factorW=2;
const int isNCHW = 1; // data format, default is NCHW
auto input = NDArrayFactory::create<float>('c', {bS, iC, iH, iW});
auto gradO = NDArrayFactory::create<float>('c', {bS, iC, iH*factorH, iW*factorW});
gradO = 1.;
auto expGradI = NDArrayFactory::create<float>('c', {bS, iC, iH, iW});
expGradI = 4.;
nd4j::ops::upsampling2d_bp op;
auto results = op.execute({&input, &gradO}, {}, {isNCHW});
auto* gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expGradI.isSameShape(gradI));
ASSERT_TRUE(expGradI.equalsTo(gradI));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests2, upsampling2d_bp_2) {
const int bS=1, iH=2,iW=2, iC=1;
const int factorH=2, factorW=2;
const int isNCHW = 0; // data format, default is NCHW
auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
auto gradO = NDArrayFactory::create<float>('c', {bS, iH*factorH, iW*factorW, iC});
gradO = 1.;
auto expGradI = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
expGradI = 4.;
nd4j::ops::upsampling2d_bp op;
auto results = op.execute({&input, &gradO}, {}, {isNCHW});
auto* gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expGradI.isSameShape(gradI));
ASSERT_TRUE(expGradI.equalsTo(gradI));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests2, upsampling2d_bp_3) {
const int bS=1, iH=3,iW=3, iC=2;
const int factorH=2, factorW=2;
const int isNCHW = 1; // data format, default is NCHW
NDArray input('c', {bS, iC, iH, iW}, nd4j::DataType::FLOAT32);
NDArray gradO('c', {bS, iC, iH*factorH, iW*factorW}, {0.6793504, 0.35508695, 0.84278935, 0.20031333, 0.7014987, 0.31069338, 0.44793984,
0.93800974, 0.32667395, 0.15187258, 0.38331753, 0.78212297, 0.1988072, 0.7985636, 0.1632634, 0.14696825, 0.26089668, 0.13505761,
0.7562093, 0.27545404, 0.36908787, 0.09282647, 0.83649176, 0.26841334, 0.09506222, 0.31279507, 0.13591796, 0.5175439, 0.32870287,
0.061735712, 0.39643127, 0.248016, 0.5489592, 0.115046196, 0.8143622, 0.7215636, 0.40449402, 0.29908907, 0.4038839, 0.9883108,
0.022296403, 0.927782, 0.3184157, 0.0685462, 0.28453344, 0.23272, 0.35214192, 0.058909304, 0.7112212, 0.6744568, 0.19694561, 0.6994972,
0.0743224, 0.42042503, 0.5842631, 0.14957358, 0.44640633, 0.72307247, 0.06448108, 0.48307765, 0.8759956, 0.5698191, 0.4458631, 0.5277549,
0.016646361, 0.753678, 0.14063567, 0.7541292, 0.16193217, 0.7750374, 0.3326449, 0.11739397}, nd4j::DataType::FLOAT32);
NDArray expGradI('c', {bS, iC, iH, iW}, {2.4203868, 1.5216494, 2.1776323, 2.0290341, 0.772146, 1.5008594, 1.0523045, 1.3174672, 1.9263644,
1.090545, 1.9094483, 1.3611296, 2.1195147, 2.0659215, 1.0423062, 2.3405795, 1.9105877, 1.2203633}, nd4j::DataType::FLOAT32);
nd4j::ops::upsampling2d_bp op;
auto results = op.execute({&input, &gradO}, {}, {isNCHW});
auto* gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expGradI.isSameShape(gradI));
ASSERT_TRUE(expGradI.equalsTo(gradI));
delete results;
}
#endif //LIBND4J_CONVOLUTIONTESTS2_H #endif //LIBND4J_CONVOLUTIONTESTS2_H

View File

@ -640,40 +640,6 @@ TEST_F(DeclarableOpsTests1, ClipByValue1) {
delete block; delete block;
} }
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, MergeMaxTest1) {
auto x = NDArrayFactory::create_<float>('c', {5, 5});
auto y = NDArrayFactory::create_<float>('c', {5, 5});
auto z = NDArrayFactory::create_<float>('c', {5, 5});
auto exp = NDArrayFactory::create<float>('c', {5, 5});
x->assign(3);
y->assign(1);
z->assign(2);
exp.assign(3);
auto zu = NDArrayFactory::create<float>('c', {5, 5});
auto variableSpace = new VariableSpace();
variableSpace->putVariable(-1, x);
variableSpace->putVariable(-2, y);
variableSpace->putVariable(-3, z);
variableSpace->putVariable(1, new Variable(NDArrayFactory::create_<float>('c', {5, 5})));
auto block = new Context(1, variableSpace, false);
block->fillInputs({-1, -2, -3});
nd4j::ops::mergemax merge;
merge.execute(block);
auto res = variableSpace->getVariable(1)->getNDArray();
ASSERT_TRUE(res->equalsTo(&exp));
delete block;
delete variableSpace;
}
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, MergeAvgTest1) { TEST_F(DeclarableOpsTests1, MergeAvgTest1) {

View File

@ -845,5 +845,44 @@ TEST_F(DeclarableOpsTests13, batch_to_space_nd_3) {
delete result; delete result;
} }
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, mergemax_1) {
NDArray x1('c', {5, 5}, nd4j::DataType::FLOAT32);
NDArray x2('c', {5, 5}, nd4j::DataType::FLOAT32);
NDArray x3('c', {5, 5}, nd4j::DataType::FLOAT32);
NDArray e('c', {5, 5}, nd4j::DataType::FLOAT32);
x1.assign(3);
x2.assign(1);
x3.assign(2);
e.assign(3);
nd4j::ops::mergemax op;
auto result = op.execute({&x1, &x2, &x3}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printBuffer();
ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(z));
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, mergemax_2) {
NDArray x1('c', {1, 3}, {0., 1, 2}, nd4j::DataType::FLOAT32);
NDArray x2('c', {1, 1}, {1.}, nd4j::DataType::FLOAT32);
NDArray out('c', {1, 3}, {-1., -1, -1}, nd4j::DataType::FLOAT32);
nd4j::ops::mergemax op;
auto status = op.execute({&x1, &x2}, {&out}, {}, {}, {});
ASSERT_EQ(20, status);
}

View File

@ -1548,8 +1548,8 @@ TEST_F(DeclarableOpsTests6, LogMatrixDeterminant_1) {
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
//z->printIndexedBuffer("Output "); z->printIndexedBuffer("Log ABS Output ");
//exp.printIndexedBuffer("Expected "); exp.printIndexedBuffer("Log ABS Expected ");
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
@ -1671,6 +1671,40 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
delete result; delete result;
} }
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_010) {
auto x = NDArrayFactory::create<float>('c', {1, 5, 5}, {
1., 0., 0., 0., 0.,
2., 1., 0., 0., 0.,
30., 2., 1., 0., 0.,
4., 3., 2., 1., 0.,
5., 4., 3., 2., 1.,
});
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {
1.0, 0.0, 0.0, 0.0, 0.,
-2.0, 1.0, 0., 0., 0.,
-26.0, -2.0, 1, 0, 0.,
54.0, 1.0, -2.0, 1, 0.,
-27.0, 0.0, 1.0, -2.0, 1.
});
nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printIndexedBuffer("010 Output ");
// exp.printIndexedBuffer("010 Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_01) { TEST_F(DeclarableOpsTests6, MatrixInverse_01) {
@ -1824,7 +1858,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_03) {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_3) { TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
auto x = NDArrayFactory::create<double>('c', {5, 5}, { auto x = NDArrayFactory::create<float>('c', {5, 5}, {
4., 0., 0., 0., 0., 4., 0., 0., 0., 0.,
4., 2., 0., 0., 0., 4., 2., 0., 0., 0.,
30., 2., 1., 0., 0., 30., 2., 1., 0., 0.,
@ -1832,7 +1866,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
15., 12., 9., 6., 3., 15., 12., 9., 6., 3.,
}); });
auto exp = NDArrayFactory::create<double>('c', {5, 5}, { auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
0.25, 0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.0, 0.0, 0.0,
-0.50, 0.5, 0.0, 0.0, 0.0, -0.50, 0.5, 0.0, 0.0, 0.0,
-6.50, -1.0, 1.0, 0.0, 0.0, -6.50, -1.0, 1.0, 0.0, 0.0,
@ -1841,13 +1875,13 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
}); });
nd4j::ops::matrix_inverse op; nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::DOUBLE); auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
exp.printIndexedBuffer("Expected "); // exp.printIndexedBuffer("Expected ");
z->printIndexedBuffer("Output "); // z->printIndexedBuffer("Output ");
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
@ -1880,8 +1914,42 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
z->printIndexedBuffer("Output "); // z->printIndexedBuffer("Output ");
exp.printIndexedBuffer("Expected "); // exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_04) {
auto x = NDArrayFactory::create<float>('c', {5, 5}, {
1., 2., 30., 4., 5.,
0., 1., 2., 3., 4.,
0., 0., 1., 2., 3.,
0., 0., 0., 1., 2.,
0., 0., 0., 0., 1.
});
auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
1.0, -2.0, -26.0, 54.0, -27.0,
0.0, 1.0, -2.0, 1.0, 0.0,
0.0, 0.0, 1.0, -2.0, 1.0,
0.0, 0.0, 0.0, 1.0, -2.0,
0.0, 0.0, 0.0, 0.0, 1.0
});
nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printIndexedBuffer("Output ");
// exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));

View File

@ -9368,10 +9368,23 @@ public static final int PREALLOC_SIZE = 33554432;
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public OpArgsHolder(Pointer p) { super(p); } public OpArgsHolder(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public OpArgsHolder(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public OpArgsHolder position(long position) {
return (OpArgsHolder)super.position(position);
}
// default constructor
public OpArgsHolder() { super((Pointer)null); allocate(); }
private native void allocate();
// copy constructor
public OpArgsHolder(@Const @ByRef OpArgsHolder other) { super((Pointer)null); allocate(other); }
private native void allocate(@Const @ByRef OpArgsHolder other);
// constructor
public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); }
private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/); private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/);
public OpArgsHolder(@Const @ByRef NDArrayVector inArrs) { super((Pointer)null); allocate(inArrs); } public OpArgsHolder(@Const @ByRef NDArrayVector inArrs) { super((Pointer)null); allocate(inArrs); }
@ -9387,6 +9400,13 @@ public static final int PREALLOC_SIZE = 33554432;
public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); }
private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/); private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/);
// move constructor
// assignment operator
public native @ByRef @Name("operator =") OpArgsHolder put(@Const @ByRef OpArgsHolder other);
// move assignment operator
public native @Const @ByRef NDArrayVector getInArrs(); public native @Const @ByRef NDArrayVector getInArrs();
public native @StdVector DoublePointer getTArgs(); public native @StdVector DoublePointer getTArgs();

View File

@ -9065,10 +9065,23 @@ public static final int PREALLOC_SIZE = 33554432;
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public OpArgsHolder(Pointer p) { super(p); } public OpArgsHolder(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public OpArgsHolder(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public OpArgsHolder position(long position) {
return (OpArgsHolder)super.position(position);
}
// default constructor
public OpArgsHolder() { super((Pointer)null); allocate(); }
private native void allocate();
// copy constructor
public OpArgsHolder(@Const @ByRef OpArgsHolder other) { super((Pointer)null); allocate(other); }
private native void allocate(@Const @ByRef OpArgsHolder other);
// constructor
public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); }
private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/); private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/);
public OpArgsHolder(@Const @ByRef NDArrayVector inArrs) { super((Pointer)null); allocate(inArrs); } public OpArgsHolder(@Const @ByRef NDArrayVector inArrs) { super((Pointer)null); allocate(inArrs); }
@ -9084,6 +9097,13 @@ public static final int PREALLOC_SIZE = 33554432;
public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); }
private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/); private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/);
// move constructor
// assignment operator
public native @ByRef @Name("operator =") OpArgsHolder put(@Const @ByRef OpArgsHolder other);
// move assignment operator
public native @Const @ByRef NDArrayVector getInArrs(); public native @Const @ByRef NDArrayVector getInArrs();
public native @StdVector DoublePointer getTArgs(); public native @StdVector DoublePointer getTArgs();

View File

@ -45,6 +45,7 @@ import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.nativeblas.NativeOpsHolder;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@ -548,6 +549,8 @@ public class CustomOpsTests extends BaseNd4jTest {
Nd4j.exec(op); //Execution is OK Nd4j.exec(op); //Execution is OK
} }
@Test @Test
public void testDepthwise(){ public void testDepthwise(){
INDArray input = Nd4j.create(DataType.DOUBLE, 1,3,8,8); INDArray input = Nd4j.create(DataType.DOUBLE, 1,3,8,8);
@ -625,6 +628,49 @@ public class CustomOpsTests extends BaseNd4jTest {
System.out.println(out); System.out.println(out);
} }
@Test
public void testUpsampling2dBackprop(){
Nd4j.getRandom().setSeed(12345);
int c = 2;
int[] sz = {2,2};
long[] inSize = {1, c, 3, 3};
INDArray eps = Nd4j.rand(DataType.FLOAT, 1, c, sz[0] * inSize[2], sz[1] * inSize[3]);
INDArray input = Nd4j.create(inSize); //Unused, not sure why this is even an arg...
INDArray exp = Nd4j.create(DataType.FLOAT, inSize);
for( int ch=0; ch<c; ch++ ) {
for( int h=0; h<eps.size(2); h++ ){
for( int w=0; w<eps.size(3); w++ ){
int[] from = new int[]{0, ch, h, w};
int[] to = new int[]{0, ch, h/sz[0], w/sz[1]};
float add = eps.getFloat(from);
float current = exp.getFloat(to);
exp.putScalar(to, current + add);
}
}
}
System.out.println("Eps:");
System.out.println(eps.shapeInfoToString());
System.out.println(Arrays.toString(eps.data().asFloat()));
System.out.println("Expected:");
System.out.println(exp.shapeInfoToString());
System.out.println(Arrays.toString(exp.data().asFloat()));
DynamicCustomOp op = DynamicCustomOp.builder("upsampling2d_bp")
.addInputs(input, eps)
.addOutputs(exp.ulike())
.addIntegerArguments(1) //1 = NCHW
.build();
Nd4j.exec(op);
INDArray act = op.getOutputArgument(0);
assertEquals(exp, act);
}
@Test @Test
public void testIsMaxView(){ public void testIsMaxView(){