[WIP] more fixes (#159)
* Added test for MatrixInverse with double input. Fixed matrixDeterminantKernel. * Fixed kernels to avoid waste templating. * Fixed logDeterminant kernel. * Refactored type check for lup' * - decrease blockDim value for zeta op Signed-off-by: Yurii <yurii@skymind.io> * Added print for compound matrix with CUDA. * Refactored upper matrix invertion kernels. * - provide move constructor and move assignment operator for OpArgsHoder class Signed-off-by: Yurii <yurii@skymind.io> * Refactored usage of launch context. * - add test for mergemax Signed-off-by: Yurii <yurii@skymind.io> * get rid of AveragingArrayProxy Signed-off-by: raver119 <raver119@gmail.com> * Refactoring of LUP inversion. * Added prints for invertion. * - add OpArgsHolder copy constructor and assignment operator Signed-off-by: Yurii <yurii@skymind.io> * Added test for lower inversion * - fix bug in upsampling2d/3d_bp op Signed-off-by: Yurii <yurii@skymind.io> * Added expensive printfs to kernel. * Refactored expensive kernel prints. * Refactored expensive printfs * - remove nullify Signed-off-by: Yurii <yurii@skymind.io> * Eliminated waste prints with tests. * upsampling2d_bp test Signed-off-by: raver119 <raver119@gmail.com> * test updated Signed-off-by: raver119 <raver119@gmail.com>master
parent
99cdf6d42b
commit
f03b0ee78f
|
@ -1,59 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
|
||||
#ifndef DEV_TESTS_AVERAGINGARRAYPROXY_H
|
||||
#define DEV_TESTS_AVERAGINGARRAYPROXY_H
|
||||
|
||||
#include "NDArray.h"
|
||||
#include <utility>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
|
||||
namespace nd4j {
|
||||
class ND4J_EXPORT AveragingArrayProxy {
|
||||
protected:
|
||||
NDArray *_original;
|
||||
|
||||
std::map<std::pair<int,int>, NDArray*> _writeables;
|
||||
std::map<int, std::vector<NDArray*>> _writeablesLinear;
|
||||
std::vector<int> _rows;
|
||||
|
||||
std::vector<NDArray*> _references;
|
||||
|
||||
std::mutex _lock;
|
||||
public:
|
||||
explicit AveragingArrayProxy(NDArray *original);
|
||||
~AveragingArrayProxy();
|
||||
|
||||
NDArray* readable(int row, int key);
|
||||
NDArray* writeable(int row, int key);
|
||||
|
||||
bool isEmpty();
|
||||
|
||||
bool writeableExists(std::pair<int, int> &key);
|
||||
bool writeableExists(int row, int key);
|
||||
|
||||
bool collapseWrites();
|
||||
};
|
||||
}
|
||||
|
||||
#endif //DEV_TESTS_AVERAGINGARRAYPROXY_H
|
|
@ -26,27 +26,42 @@
|
|||
#include <dll.h>
|
||||
|
||||
namespace nd4j {
|
||||
|
||||
|
||||
class ND4J_EXPORT OpArgsHolder {
|
||||
|
||||
private:
|
||||
private:
|
||||
|
||||
std::vector<NDArray*> _inArrs = std::vector<NDArray*>();
|
||||
std::vector<double> _tArgs = std::vector<double>();
|
||||
std::vector<Nd4jLong> _iArgs = std::vector<Nd4jLong>();
|
||||
std::vector<bool> _bArgs = std::vector<bool>();
|
||||
std::vector<double> _tArgs = std::vector<double>();
|
||||
std::vector<Nd4jLong> _iArgs = std::vector<Nd4jLong>();
|
||||
std::vector<bool> _bArgs = std::vector<bool>();
|
||||
|
||||
std::vector<bool> _isArrAlloc = std::vector<bool>();
|
||||
|
||||
int _numInArrs = _inArrs.size();
|
||||
int _numTArgs = _tArgs.size();
|
||||
int _numIArgs = _iArgs.size();
|
||||
int _numBArgs = _bArgs.size();
|
||||
std::vector<bool> _isArrAlloc = std::vector<bool>();
|
||||
|
||||
public:
|
||||
|
||||
OpArgsHolder() = delete;
|
||||
// default constructor
|
||||
OpArgsHolder();
|
||||
|
||||
OpArgsHolder(const std::vector<NDArray*>& inArrs, const std::vector<double>& tArgs = std::vector<double>(), const std::vector<Nd4jLong>& iArgs = std::vector<Nd4jLong>(), const std::vector<bool>& bArgs = std::vector<bool>())
|
||||
: _inArrs(inArrs), _tArgs(tArgs), _iArgs(iArgs), _bArgs(bArgs) { }
|
||||
// copy constructor
|
||||
OpArgsHolder(const OpArgsHolder& other);
|
||||
|
||||
// constructor
|
||||
OpArgsHolder(const std::vector<NDArray*>& inArrs, const std::vector<double>& tArgs = std::vector<double>(), const std::vector<Nd4jLong>& iArgs = std::vector<Nd4jLong>(), const std::vector<bool>& bArgs = std::vector<bool>());
|
||||
|
||||
// move constructor
|
||||
OpArgsHolder(OpArgsHolder&& other) noexcept;
|
||||
|
||||
// assignment operator
|
||||
OpArgsHolder& operator=(const OpArgsHolder& other);
|
||||
|
||||
// move assignment operator
|
||||
OpArgsHolder& operator=(OpArgsHolder&& other) noexcept;
|
||||
|
||||
const std::vector<NDArray*>& getInArrs() const
|
||||
{return _inArrs; }
|
||||
|
@ -77,8 +92,8 @@ public:
|
|||
|
||||
OpArgsHolder createArgsHolderForBP(const std::vector<NDArray*>& inGradArrs, const bool isInPlace = false) const;
|
||||
|
||||
~OpArgsHolder() noexcept;
|
||||
|
||||
~OpArgsHolder() noexcept;
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -229,7 +229,6 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
|
|||
status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc);
|
||||
}
|
||||
else if(ABC && aType == DataType::HALF) {
|
||||
printf("!!!!!!!!\n");
|
||||
float16 alphaH(alpha), betaH(beta);
|
||||
status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc);
|
||||
}
|
||||
|
|
|
@ -1,139 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include "../AveragingArrayProxy.h"
|
||||
|
||||
namespace nd4j {
|
||||
AveragingArrayProxy::AveragingArrayProxy(NDArray *original) {
|
||||
_original = original;
|
||||
}
|
||||
|
||||
AveragingArrayProxy::~AveragingArrayProxy() {
|
||||
for (auto v:_references)
|
||||
delete v;
|
||||
}
|
||||
|
||||
bool AveragingArrayProxy::writeableExists(std::pair<int, int> &key) {
|
||||
_lock.lock();
|
||||
|
||||
auto r = _writeables.count(key) > 0;
|
||||
|
||||
_lock.unlock();
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
bool AveragingArrayProxy::writeableExists(int row, int key) {
|
||||
std::pair<int, int> k(row, key);
|
||||
return writeableExists(k);
|
||||
}
|
||||
|
||||
NDArray* AveragingArrayProxy::readable(int row, int key) {
|
||||
std::pair<int, int> k(row, key);
|
||||
|
||||
if (writeableExists(k)) {
|
||||
_lock.lock();
|
||||
|
||||
auto r = _writeables[k];
|
||||
|
||||
_lock.unlock();
|
||||
|
||||
return r;
|
||||
} else {
|
||||
auto readable = (*_original)({row,row+1, 0,0});
|
||||
|
||||
_lock.lock();
|
||||
|
||||
_references.emplace_back(&readable);
|
||||
|
||||
_lock.unlock();
|
||||
|
||||
// return readable;
|
||||
}
|
||||
}
|
||||
|
||||
bool AveragingArrayProxy::isEmpty() {
|
||||
return _original->isEmpty();
|
||||
}
|
||||
|
||||
NDArray* AveragingArrayProxy::writeable(int row, int key) {
|
||||
std::pair<int, int> k(row, key);
|
||||
|
||||
// if writeable exists - just return it
|
||||
if (writeableExists(k)) {
|
||||
_lock.lock();
|
||||
|
||||
auto r = _writeables[k];
|
||||
|
||||
_lock.unlock();
|
||||
|
||||
return r;
|
||||
} else {
|
||||
// if doesn't - let's create it
|
||||
auto orig = (*_original)({row,row+1, 0,0});
|
||||
|
||||
// we don't want views here for obvious reasons
|
||||
auto writeable = orig.dup('c');
|
||||
|
||||
_lock.lock();
|
||||
|
||||
_writeables[k] = writeable;
|
||||
_references.emplace_back(writeable);
|
||||
|
||||
// storing linear reference, for future averaging step
|
||||
if (_writeablesLinear.count(row) == 0) {
|
||||
std::vector<NDArray*> vec;
|
||||
_writeablesLinear[row] = vec;
|
||||
}
|
||||
|
||||
_writeablesLinear[row].emplace_back(writeable);
|
||||
_rows.emplace_back(row);
|
||||
|
||||
_lock.unlock();
|
||||
|
||||
return writeable;
|
||||
}
|
||||
}
|
||||
|
||||
bool AveragingArrayProxy::collapseWrites() {
|
||||
if (_writeables.empty())
|
||||
return false;
|
||||
|
||||
for (int r = 0; r < _rows.size(); r++) {
|
||||
auto row = _rows[r];
|
||||
auto list = _writeablesLinear.at(row);
|
||||
|
||||
auto originalRow = (*_original)({row,row+1, 0,0});
|
||||
|
||||
if (list.size() == 1) {
|
||||
originalRow.assign(list.at(0));
|
||||
} else {
|
||||
originalRow.assign(0.0);
|
||||
|
||||
for (int e = 0; e < list.size(); e++)
|
||||
originalRow += *(list.at(e));
|
||||
|
||||
originalRow /= (int) list.size();
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
|
@ -23,27 +23,122 @@
|
|||
|
||||
namespace nd4j {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
// default constructor
|
||||
OpArgsHolder::OpArgsHolder() {
|
||||
|
||||
_inArrs = std::vector<NDArray*>();
|
||||
_tArgs = std::vector<double>();
|
||||
_iArgs = std::vector<Nd4jLong>();
|
||||
_bArgs = std::vector<bool>();
|
||||
|
||||
_isArrAlloc = std::vector<bool>();
|
||||
|
||||
_numInArrs = 0;
|
||||
_numTArgs = 0;
|
||||
_numIArgs = 0;
|
||||
_numBArgs = 0;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
// copy constructor
|
||||
OpArgsHolder::OpArgsHolder(const OpArgsHolder& other) {
|
||||
|
||||
throw std::runtime_error("OpArgsHolder::OpArgsHolder copy constructor: don't use me !");
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
// constructor
|
||||
OpArgsHolder::OpArgsHolder(const std::vector<NDArray*>& inArrs,
|
||||
const std::vector<double>& tArgs,
|
||||
const std::vector<Nd4jLong>& iArgs,
|
||||
const std::vector<bool>& bArgs) {
|
||||
_inArrs = inArrs;
|
||||
_tArgs = tArgs;
|
||||
_iArgs = iArgs;
|
||||
_bArgs = bArgs;
|
||||
|
||||
_isArrAlloc = std::vector<bool>();
|
||||
|
||||
_numInArrs = _inArrs.size();
|
||||
_numTArgs = _tArgs.size();
|
||||
_numIArgs = _iArgs.size();
|
||||
_numBArgs = _bArgs.size();
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
// move constructor
|
||||
OpArgsHolder::OpArgsHolder(OpArgsHolder&& other) noexcept: _inArrs(std::move(other._inArrs)),
|
||||
_tArgs(std::move(other._tArgs)),
|
||||
_iArgs(std::move(other._iArgs)),
|
||||
_bArgs(std::move(other._bArgs)),
|
||||
_isArrAlloc(std::move(other._isArrAlloc)) {
|
||||
|
||||
other._isArrAlloc = std::vector<bool>();
|
||||
|
||||
_numInArrs = _inArrs.size();
|
||||
_numTArgs = _tArgs.size();
|
||||
_numIArgs = _iArgs.size();
|
||||
_numBArgs = _bArgs.size();
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
// assignment operator
|
||||
OpArgsHolder& OpArgsHolder::operator=(const OpArgsHolder& other) {
|
||||
|
||||
throw std::runtime_error("OpArgsHolder::OpArgsHolder assignment operator: don't use me !");
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
// move assignment operator
|
||||
OpArgsHolder& OpArgsHolder::operator=(OpArgsHolder&& other) noexcept {
|
||||
|
||||
if (this == &other)
|
||||
return *this;
|
||||
|
||||
for (int i = 0; i < _isArrAlloc.size(); ++i) // delete arrays if necessary
|
||||
if(_isArrAlloc[i])
|
||||
delete _inArrs[i];
|
||||
|
||||
_inArrs = std::move(other._inArrs);
|
||||
_tArgs = std::move(other._tArgs);
|
||||
_iArgs = std::move(other._iArgs);
|
||||
_bArgs = std::move(other._bArgs);
|
||||
_isArrAlloc = std::move(other._isArrAlloc);
|
||||
|
||||
other._isArrAlloc = std::vector<bool>();
|
||||
|
||||
_numInArrs = _inArrs.size();
|
||||
_numTArgs = _tArgs.size();
|
||||
_numIArgs = _iArgs.size();
|
||||
_numBArgs = _bArgs.size();
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
OpArgsHolder OpArgsHolder::createArgsHolderForBP(const std::vector<NDArray*>& inGradArrs, const bool isInPlace) const {
|
||||
|
||||
|
||||
const int numInGradArrs = inGradArrs.size();
|
||||
|
||||
OpArgsHolder result(std::vector<NDArray*>(_numInArrs + numInGradArrs, nullptr), _tArgs, _iArgs);
|
||||
|
||||
|
||||
if(isInPlace)
|
||||
result._isArrAlloc = std::vector<bool>(_numInArrs + numInGradArrs, false);
|
||||
|
||||
for (int i = 0; i < _numInArrs; ++i) {
|
||||
|
||||
if(isInPlace) {
|
||||
|
||||
if(isInPlace) {
|
||||
result._inArrs[i] = new NDArray(*_inArrs[i]); // make copy
|
||||
result._isArrAlloc[i] = true;
|
||||
}
|
||||
else
|
||||
result._inArrs[i] = _inArrs[i];
|
||||
else
|
||||
result._inArrs[i] = _inArrs[i];
|
||||
}
|
||||
|
||||
// input gradients
|
||||
// input gradients
|
||||
for (int i = 0; i < numInGradArrs; ++i)
|
||||
result._inArrs[_numInArrs + i] = inGradArrs[i];
|
||||
|
||||
|
@ -53,11 +148,10 @@ OpArgsHolder OpArgsHolder::createArgsHolderForBP(const std::vector<NDArray*>& in
|
|||
////////////////////////////////////////////////////////////////////////
|
||||
// default destructor
|
||||
OpArgsHolder::~OpArgsHolder() noexcept {
|
||||
|
||||
|
||||
for (int i = 0; i < _isArrAlloc.size(); ++i)
|
||||
if(_isArrAlloc[i])
|
||||
delete _inArrs[i];
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) {
|
|||
auto input = INPUT_VARIABLE(i);
|
||||
auto currentRank = input->rankOf();
|
||||
|
||||
// TODO: follow two lines are accordingly with current tf.concat spec. Commented for compatibility with legacy
|
||||
// TODO: follow two lines are in accordance to current tf.concat spec. Commented for compatibility with legacy
|
||||
// REQUIRE_TRUE(currentRank > 0, 0, "Rank of input variable %i must be greater 0, but is %lld instead.", i, currentRank);
|
||||
// REQUIRE_TRUE(theFirstRank == currentRank, 0, "Number of dimensions in concat should be equals, but for %i input variable %lld != %lld appears.", i, currentRank, theFirstRank);
|
||||
if(!input->isEmpty()) {
|
||||
|
|
|
@ -1147,8 +1147,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
|
|||
// gradO has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC)
|
||||
// gradI has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC)
|
||||
|
||||
gradI.nullify();
|
||||
|
||||
const T* x = gradO.bufferAsT<T>();
|
||||
T* z = gradI.bufferAsT<T>();
|
||||
|
||||
|
@ -1182,8 +1180,10 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
|
|||
|
||||
const auto zOffset = b*zStride0 + c*zStride1 + h*zStride2 + w*zStride3;
|
||||
|
||||
for(uint xh = h; xh < h + factorH; ++xh)
|
||||
for(uint xw = w; xw < w + factorW; ++xw)
|
||||
z[zOffset] = 0;
|
||||
|
||||
for(uint xh = h * factorH; xh < h * factorH + factorH; ++xh)
|
||||
for(uint xw = w * factorW; xw < w * factorW + factorW; ++xw)
|
||||
z[zOffset] += x[b*xStride0 + c*xStride1 + xh*xStride2 + xw*xStride3];
|
||||
}
|
||||
}
|
||||
|
@ -1198,8 +1198,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
|
|||
// input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC)
|
||||
// output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC)
|
||||
|
||||
gradI.nullify();
|
||||
|
||||
const T* x = gradO.bufferAsT<T>();
|
||||
T* z = gradI.bufferAsT<T>();
|
||||
|
||||
|
@ -1238,9 +1236,11 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
|
|||
|
||||
const auto zOffset = b*zStride0 + c*zStride1 + d*zStride2 + h*zStride3 + w*zStride4;
|
||||
|
||||
for(uint xd = d; xd < d + factorD; ++xd)
|
||||
for(uint xh = h; xh < h + factorH; ++xh)
|
||||
for(uint xw = w; xw < w + factorW; ++xw)
|
||||
z[zOffset] = 0;
|
||||
|
||||
for(uint xd = d * factorD; xd < d * factorD + factorD; ++xd)
|
||||
for(uint xh = h * factorH; xh < h * factorH + factorH; ++xh)
|
||||
for(uint xw = w * factorW; xw < w * factorW + factorW; ++xw)
|
||||
z[zOffset] += x[b*xStride0 + c*xStride1 + xd*xStride2 + xh*xStride3 + xw*xStride4];
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext();
|
||||
|
||||
template <typename T>
|
||||
static void swapRows_(NDArray* matrix, int theFirst, int theSecond) {
|
||||
|
@ -114,7 +115,7 @@ namespace helpers {
|
|||
|
||||
NDArray determinant = NDArrayFactory::create<T>(1.f);
|
||||
NDArray compoundMatrix = *input; // copy
|
||||
NDArray permutationMatrix(input, false, input->getContext()); // has same shape as input and contiguous strides
|
||||
NDArray permutationMatrix(input, false, defaultContext); // has same shape as input and contiguous strides
|
||||
permutationMatrix.setIdentity();
|
||||
|
||||
T pivotValue; // = T(0.0);
|
||||
|
@ -170,7 +171,7 @@ namespace helpers {
|
|||
Nd4jLong n = input->sizeAt(-1);
|
||||
Nd4jLong n2 = n * n;
|
||||
|
||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), input->getContext()); //, block.getWorkspace());
|
||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
|
||||
|
||||
for (int e = 0; e < output->lengthOf(); e++) {
|
||||
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row)
|
||||
|
@ -184,6 +185,7 @@ namespace helpers {
|
|||
BUILD_SINGLE_TEMPLATE(template int determinant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
|
||||
|
||||
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||
defaultContext = context;
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (input, output), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
|
@ -193,7 +195,7 @@ template <typename T>
|
|||
Nd4jLong n = input->sizeAt(-1);
|
||||
Nd4jLong n2 = n * n;
|
||||
|
||||
NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), input->getContext()); //, block.getWorkspace());
|
||||
NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
|
||||
for (int e = 0; e < output->lengthOf(); e++) {
|
||||
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) {
|
||||
matrix.p(row, input->e<T>(k));
|
||||
|
@ -220,11 +222,11 @@ template <typename T>
|
|||
auto totalCount = output->lengthOf() / n2;
|
||||
|
||||
output->assign(0.f); // fill up output tensor with zeros
|
||||
auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), input->getContext()); //, block.getWorkspace());
|
||||
auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), input->getContext()); //, block.getWorkspace());
|
||||
auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), input->getContext());
|
||||
auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), input->getContext());
|
||||
auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), input->getContext());
|
||||
auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
|
||||
auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
|
||||
auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
|
||||
auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
|
||||
auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
|
||||
|
||||
for (int e = 0; e < totalCount; e++) {
|
||||
if (e)
|
||||
|
@ -266,6 +268,7 @@ template <typename T>
|
|||
}
|
||||
|
||||
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||
defaultContext = context;
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (input, output), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
|
@ -308,8 +311,8 @@ template <typename T>
|
|||
if (!inplace)
|
||||
output->assign(0.f); // fill up output tensor with zeros only inplace=false
|
||||
|
||||
std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), input->getContext())); //, block.getWorkspace());
|
||||
std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), input->getContext()));
|
||||
std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), defaultContext)); //, block.getWorkspace());
|
||||
std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), defaultContext));
|
||||
|
||||
for (int e = 0; e < totalCount; e++) {
|
||||
|
||||
|
@ -343,6 +346,7 @@ template <typename T>
|
|||
}
|
||||
|
||||
int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) {
|
||||
defaultContext = context;
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
|
|
|
@ -19,8 +19,6 @@
|
|||
//
|
||||
|
||||
#include <ops/declarable/helpers/sg_cb.h>
|
||||
#include <AveragingArrayProxy.h>
|
||||
#include <helpers/AveragingArrayProxy.h>
|
||||
#include <specials.h>
|
||||
|
||||
#define HS_MAX_EXP 6.0f
|
||||
|
|
|
@ -1496,8 +1496,10 @@ __global__ static void upsampling2dBPCuda(const void* vx, const Nd4jLong* xShape
|
|||
|
||||
const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
|
||||
|
||||
const Nd4jLong zCoord2 = coords[dimIH];
|
||||
const Nd4jLong zCoord3 = coords[dimIH + 1];
|
||||
z[zOffset] = 0;
|
||||
|
||||
const Nd4jLong zCoord2 = coords[dimIH] * factorH;
|
||||
const Nd4jLong zCoord3 = coords[dimIH + 1] * factorW;
|
||||
|
||||
for(coords[dimIH] = zCoord2; coords[dimIH] < zCoord2 + factorH; ++coords[dimIH])
|
||||
for(coords[dimIH + 1] = zCoord3; coords[dimIH + 1] < zCoord3 + factorW; ++coords[dimIH + 1])
|
||||
|
@ -1569,9 +1571,11 @@ __global__ static void upsampling3dBPCuda(const void* vx, const Nd4jLong* xShape
|
|||
|
||||
const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
|
||||
|
||||
const Nd4jLong zCoord2 = coords[dimID];
|
||||
const Nd4jLong zCoord3 = coords[dimID + 1];
|
||||
const Nd4jLong zCoord4 = coords[dimID + 2];
|
||||
z[zOffset] = 0;
|
||||
|
||||
const Nd4jLong zCoord2 = coords[dimID] * factorD;
|
||||
const Nd4jLong zCoord3 = coords[dimID + 1] * factorH;
|
||||
const Nd4jLong zCoord4 = coords[dimID + 2] * factorW;
|
||||
|
||||
for(coords[dimID] = zCoord2; coords[dimID] < zCoord2 + factorD; ++coords[dimID])
|
||||
for(coords[dimID + 1] = zCoord3; coords[dimID + 1] < zCoord3 + factorH; ++coords[dimID + 1])
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -33,19 +33,19 @@ __global__ static void zetaCuda(const void *vx, const Nd4jLong *xShapeInfo,
|
|||
|
||||
const auto x = reinterpret_cast<const T*>(vx);
|
||||
const auto q = reinterpret_cast<const T*>(vq);
|
||||
auto z = reinterpret_cast<T*>(vz);
|
||||
auto z = reinterpret_cast<T*>(vz);
|
||||
|
||||
__shared__ Nd4jLong len;
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
len = shape::length(xShapeInfo);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
len = shape::length(xShapeInfo);
|
||||
__syncthreads();
|
||||
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const auto totalThreads = gridDim.x * blockDim.x;
|
||||
|
||||
for (int i = tid; i < len; i += totalThreads) {
|
||||
|
||||
|
||||
const auto xOffset = shape::getIndexOffset(i, xShapeInfo, len);
|
||||
const auto qOffset = shape::getIndexOffset(i, qShapeInfo, len);
|
||||
const auto zOffset = shape::getIndexOffset(i, zShapeInfo, len);
|
||||
|
@ -65,10 +65,10 @@ void zeta(nd4j::LaunchContext * context, const NDArray& x, const NDArray& q, NDA
|
|||
|
||||
if(!x.isActualOnDeviceSide()) x.syncToDevice();
|
||||
if(!q.isActualOnDeviceSide()) q.syncToDevice();
|
||||
|
||||
int threadsPerBlock = MAX_NUM_THREADS;
|
||||
|
||||
int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||
int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||
|
||||
|
||||
BUILD_SINGLE_SELECTOR(x.dataType(), zetaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), q.getSpecialBuffer(), q.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES);
|
||||
|
||||
x.tickReadHost();
|
||||
|
|
|
@ -1,63 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include "testlayers.h"
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <AveragingArrayProxy.h>
|
||||
|
||||
using namespace nd4j;
|
||||
using namespace nd4j::ops;
|
||||
using namespace nd4j::graph;
|
||||
|
||||
class AveragingArrayTests : public testing::Test {
|
||||
public:
|
||||
|
||||
};
|
||||
|
||||
TEST_F(AveragingArrayTests, test_basic_reads_1) {
|
||||
auto exp0 = NDArrayFactory::create<double>('c', {1, 5},{3.0, 3.0, 3.0, 3.0, 3.0});
|
||||
|
||||
auto original = NDArrayFactory::create<double>('c', {100, 5});
|
||||
original.assign(1.0);
|
||||
|
||||
AveragingArrayProxy proxy(&original);
|
||||
|
||||
auto writeable0 = proxy.writeable(1, 0);
|
||||
auto writeable1 = proxy.writeable(1, 1);
|
||||
auto writeable2 = proxy.writeable(2, 1);
|
||||
|
||||
ASSERT_FALSE(writeable0 == nullptr);
|
||||
ASSERT_FALSE(writeable1 == nullptr);
|
||||
ASSERT_FALSE(writeable2 == nullptr);
|
||||
|
||||
writeable0->assign(2.0);
|
||||
writeable1->assign(4.0);
|
||||
writeable2->assign(3.0);
|
||||
|
||||
auto r = proxy.collapseWrites();
|
||||
|
||||
ASSERT_TRUE(r);
|
||||
|
||||
auto row1 = original({1,2, 0,0}, true);
|
||||
auto row2 = original({2,3, 0,0}, true);
|
||||
|
||||
ASSERT_EQ(exp0, row1);
|
||||
ASSERT_EQ(exp0, row2);
|
||||
}
|
|
@ -159,9 +159,9 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_3) {
|
|||
auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
|
||||
|
||||
|
||||
auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC},{ 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f,
|
||||
170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f,
|
||||
152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f,
|
||||
auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC},{ 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f,
|
||||
170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f,
|
||||
152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f,
|
||||
170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f});
|
||||
input = 2.;
|
||||
weights.linspace(0.1, 0.1);
|
||||
|
@ -2211,55 +2211,6 @@ TEST_F(ConvolutionTests1, upsampling3d_test2) {
|
|||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(ConvolutionTests1, upsampling2d_bp_test1) {
|
||||
|
||||
const int bS=1, iH=2,iW=2, iC=1;
|
||||
const int factorH=2, factorW=2;
|
||||
const int isNCHW = 1; // data format, default is NCHW
|
||||
|
||||
auto input = NDArrayFactory::create<float>('c', {bS, iC, iH, iW});
|
||||
auto gradO = NDArrayFactory::create<float>('c', {bS, iC, iH*factorH, iW*factorW});
|
||||
gradO = 1.;
|
||||
|
||||
auto expGradI = NDArrayFactory::create<float>('c', {bS, iC, iH, iW});
|
||||
expGradI = 4.;
|
||||
|
||||
nd4j::ops::upsampling2d_bp op;
|
||||
auto results = op.execute({&input, &gradO}, {}, {isNCHW});
|
||||
auto* gradI = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(expGradI.isSameShape(gradI));
|
||||
ASSERT_TRUE(expGradI.equalsTo(gradI));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(ConvolutionTests1, upsampling2d_bp_test2) {
|
||||
|
||||
const int bS=1, iH=2,iW=2, iC=1;
|
||||
const int factorH=2, factorW=2;
|
||||
const int isNCHW = 0; // data format, default is NCHW
|
||||
|
||||
auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
|
||||
auto gradO = NDArrayFactory::create<float>('c', {bS, iH*factorH, iW*factorW, iC});
|
||||
gradO = 1.;
|
||||
|
||||
auto expGradI = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
|
||||
expGradI = 4.;
|
||||
|
||||
nd4j::ops::upsampling2d_bp op;
|
||||
auto results = op.execute({&input, &gradO}, {}, {isNCHW});
|
||||
auto* gradI = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(expGradI.isSameShape(gradI));
|
||||
ASSERT_TRUE(expGradI.equalsTo(gradI));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(ConvolutionTests1, upsampling3d_bp_test1) {
|
||||
|
@ -2315,6 +2266,70 @@ TYPED_TEST(TypedConvolutionTests1, conv2D_input_BP_test1) {
|
|||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(ConvolutionTests1, upsampling3d_bp_test3) {
|
||||
|
||||
const int bS=1, iD=3,iH=3,iW=3, iC=2;
|
||||
const int factorD=2, factorH=2, factorW=2;
|
||||
const int isNCDHW = 1; // data format, default is NCHW
|
||||
|
||||
NDArray input('c', {bS, iC, iD, iH, iW}, nd4j::DataType::FLOAT32);
|
||||
NDArray gradO('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, {0.6793504, 0.35508695, 0.84278935, 0.20031333, 0.7014987, 0.31069338,
|
||||
0.44793984, 0.93800974, 0.32667395, 0.15187258, 0.38331753, 0.78212297, 0.1988072, 0.7985636, 0.1632634, 0.14696825, 0.26089668,
|
||||
0.13505761, 0.7562093, 0.27545404, 0.36908787, 0.09282647, 0.83649176, 0.26841334, 0.09506222, 0.31279507, 0.13591796, 0.5175439,
|
||||
0.32870287, 0.061735712, 0.39643127, 0.248016, 0.5489592, 0.115046196, 0.8143622, 0.7215636, 0.40449402, 0.29908907, 0.4038839,
|
||||
0.9883108, 0.022296403, 0.927782, 0.3184157, 0.0685462, 0.28453344, 0.23272, 0.35214192, 0.058909304, 0.7112212, 0.6744568, 0.19694561,
|
||||
0.6994972, 0.0743224, 0.42042503, 0.5842631, 0.14957358, 0.44640633, 0.72307247, 0.06448108, 0.48307765, 0.8759956, 0.5698191, 0.4458631,
|
||||
0.5277549, 0.016646361, 0.753678, 0.14063567, 0.7541292, 0.16193217, 0.7750374, 0.3326449, 0.11739397, 0.017710684, 0.60847557, 0.52515227,
|
||||
0.9171938, 0.84989065, 0.5894228, 0.85227835, 0.39063585, 0.88968325, 0.6694452, 0.698873, 0.96147966, 0.15740126, 0.15736352, 0.49352047,
|
||||
0.5699365, 0.12683152, 0.11572781, 0.7863682, 0.737939, 0.49007934, 0.6084143, 0.9564999, 0.3900982, 0.14730452, 0.8506447, 0.49765033,
|
||||
0.07186628, 0.08214969, 0.035314173, 0.7320408, 0.36993408, 0.8406658, 0.27389422, 0.43179566, 0.13323106, 0.19297548, 0.24689731, 0.38641843,
|
||||
0.51154125, 0.19903564, 0.1416313, 0.69769853, 0.25363067, 0.78221816, 0.9300991, 0.3355119, 0.5588076, 0.6643576, 0.018850708, 0.63755876,
|
||||
0.2904297, 0.43490165, 0.84251267, 0.46609768, 0.38139546, 0.52318525, 0.9901826, 0.9257676, 0.6434591, 0.016828254, 0.9187561, 0.22897908,
|
||||
0.0063138064, 0.66597503, 0.19036093, 0.59552056, 0.69888055, 0.22146936, 0.9124342, 0.8708221, 0.7273687, 0.52397245, 0.66288394, 0.2188415,
|
||||
0.3354802, 0.03566524, 0.5101009, 0.5017283, 0.75122046, 0.1884508, 0.7407126, 0.6253045, 0.47145858, 0.5369367, 0.19884548, 0.99008304,
|
||||
0.08256686, 0.91884845, 0.02360027, 0.98895234, 0.3751719, 0.91783875, 0.4338776, 0.6783008, 0.6667967, 0.46720362, 0.7508773, 0.52304846,
|
||||
0.76631916, 0.4187526, 0.7653719, 0.5159193, 0.42730415, 0.49462363, 0.2731735, 0.8862948, 0.043214794, 0.3197591, 0.040378205, 0.5427239,
|
||||
0.9228089, 0.045940384, 0.70047987, 0.8419288, 0.53966296, 0.009444186, 0.038044546, 0.03158029, 0.43485752, 0.9204235, 0.5478789, 0.8290083,
|
||||
0.11868837, 0.0229866, 0.6639305, 0.8757367, 0.8279557, 0.76270294, 0.43242732, 0.4713431, 0.2569212, 0.30575937, 0.44395888, 0.99384075,
|
||||
0.6127142, 0.44844577, 0.6347944, 0.098358564, 0.34233716, 0.9329664, 0.65776783, 0.108565055, 0.2052629, 0.46441218, 0.041791342, 0.89369565,
|
||||
0.7000381, 0.2106213, 0.51152664, 0.44200692, 0.8293282, 0.20901772, 0.6387249, 0.8016979, 0.11178707, 0.109545894, 0.19654618, 0.060582615,
|
||||
0.08239174, 0.64630795, 0.32862368, 0.60225064, 0.8328141, 0.5484566, 0.8120276, 0.38822946, 0.6742381, 0.34913155, 0.42887798, 0.45344824,
|
||||
0.73956585, 0.9714739, 0.42937812, 0.45185348, 0.84535813, 0.046436775, 0.8802151, 0.8676222, 0.42625394, 0.4985318, 0.42399272, 0.122144565,
|
||||
0.0060101906, 0.47253844, 0.18123977, 0.86316174, 0.5863874, 0.3852012, 0.9785553, 0.0054711984, 0.88500834, 0.020897374, 0.27467912, 0.3852802,
|
||||
0.0766939, 0.94622654, 0.38687763, 0.3308602, 0.7770494, 0.9052543, 0.22258204, 0.42207044, 0.18050623, 0.21057767, 0.012561422, 0.7977821,
|
||||
0.61251044, 0.7203693, 0.6028265, 0.6036933, 0.1446382, 0.6712341, 0.76634467, 0.4854034, 0.26634562, 0.76523924, 0.16348523, 0.2663676,
|
||||
0.96846986, 0.8273284, 0.10700377, 0.7600526, 0.6771002, 0.47963092, 0.21264452, 0.56934077, 0.5514792, 0.85725874, 0.99090636, 0.54562527,
|
||||
0.93597686, 0.21142527, 0.4628326, 0.35011524, 0.31464386, 0.31164807, 0.65928996, 0.94418925, 0.39666295, 0.9496393, 0.103756346, 0.482158,
|
||||
0.49171793, 0.4108867, 0.22594318, 0.97093135, 0.5974685, 0.34632966, 0.54835194, 0.10499302, 0.9767778, 0.55008715, 0.54379046, 0.3583731,
|
||||
0.33369112, 0.04279039, 0.24939054, 0.23943715, 0.06775989, 0.7750291, 0.24329625, 0.4327169, 0.86916673, 0.80322117, 0.049972698, 0.47177452,
|
||||
0.37419558, 0.15303156, 0.121425234, 0.75884604, 0.8191354, 0.48554084, 0.053899214, 0.7858246, 0.39219773, 0.77579063, 0.34507045, 0.46070176,
|
||||
0.14496958, 0.47706795, 0.50678796, 0.64902323, 0.3277943, 0.0017530271, 0.6536156, 0.8582253, 0.95703506, 0.9963951, 0.8239163, 0.305142,
|
||||
0.012419582, 0.9498972, 0.1595827, 0.47947606, 0.5071124, 0.78227425, 0.2066719, 0.5217094, 0.7841406, 0.5260441, 0.49798164, 0.10975622,
|
||||
0.8633349, 0.76298475, 0.14295428, 0.6131504, 0.43794408, 0.50339264, 0.4504877, 0.19235311, 0.6678411, 0.80769485, 0.67495126, 0.96461457,
|
||||
0.10535406, 0.66438645, 0.4372345, 0.93851465, 0.8635335, 0.3405871, 0.45652762, 0.3636232, 0.52931345, 0.20154329, 0.07698499, 0.6125804,
|
||||
0.3583082, 0.3894796, 0.32601944, 0.5237369, 0.66683626, 0.08541841, 0.4815708, 0.11897489, 0.97555137, 0.3602705, 0.9620871, 0.6361821,
|
||||
0.71167386, 0.5134439, 0.57761437, 0.58598644, 0.39387667, 0.6966405, 0.46841687, 0.85788506, 0.9957087, 0.051309288, 0.24846801, 0.55938333,
|
||||
0.10230542, 0.9370694, 0.57527155, 0.54656035, 0.28896323, 0.51303476, 0.8865, 0.38641605, 0.9836358}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray expGradI('c', {bS, iC, iD, iH, iW}, {3.510932, 3.4310975, 3.538762, 4.148549, 2.8380678, 2.5431657, 3.3928843, 3.228055, 3.1467278,
|
||||
3.2603023, 5.611751, 4.334653, 3.3697734, 4.603307, 4.4357986, 4.32991, 3.0532732, 3.1370173, 4.181534, 2.9965065, 2.8553872, 5.2719016,
|
||||
4.5671935, 3.7027276, 3.3517184, 5.2544537, 3.5107024, 4.1496124, 3.9333878, 3.1798909, 3.1446428, 3.0932689, 3.9730802, 3.0466917,
|
||||
4.9675374, 4.769673, 3.766952, 3.6375027, 3.6492167, 4.9440994, 3.8379507, 3.467589, 4.719474, 3.1295977, 4.5177174, 4.2760015, 2.8443856,
|
||||
4.225355, 4.377341, 4.4398847, 4.710785, 4.4199953, 3.928307, 4.8769503}, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::upsampling3d_bp op;
|
||||
auto results = op.execute({&input, &gradO}, {}, {isNCDHW});
|
||||
auto* gradI = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(expGradI.isSameShape(gradI));
|
||||
ASSERT_TRUE(expGradI.equalsTo(gradI));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(ConvolutionTests1, deconv2d_test1) {
|
||||
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -640,40 +640,6 @@ TEST_F(DeclarableOpsTests1, ClipByValue1) {
|
|||
delete block;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, MergeMaxTest1) {
|
||||
|
||||
auto x = NDArrayFactory::create_<float>('c', {5, 5});
|
||||
auto y = NDArrayFactory::create_<float>('c', {5, 5});
|
||||
auto z = NDArrayFactory::create_<float>('c', {5, 5});
|
||||
auto exp = NDArrayFactory::create<float>('c', {5, 5});
|
||||
x->assign(3);
|
||||
y->assign(1);
|
||||
z->assign(2);
|
||||
exp.assign(3);
|
||||
|
||||
auto zu = NDArrayFactory::create<float>('c', {5, 5});
|
||||
|
||||
auto variableSpace = new VariableSpace();
|
||||
variableSpace->putVariable(-1, x);
|
||||
variableSpace->putVariable(-2, y);
|
||||
variableSpace->putVariable(-3, z);
|
||||
variableSpace->putVariable(1, new Variable(NDArrayFactory::create_<float>('c', {5, 5})));
|
||||
auto block = new Context(1, variableSpace, false);
|
||||
block->fillInputs({-1, -2, -3});
|
||||
|
||||
nd4j::ops::mergemax merge;
|
||||
|
||||
merge.execute(block);
|
||||
|
||||
auto res = variableSpace->getVariable(1)->getNDArray();
|
||||
|
||||
ASSERT_TRUE(res->equalsTo(&exp));
|
||||
|
||||
delete block;
|
||||
delete variableSpace;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, MergeAvgTest1) {
|
||||
|
||||
|
|
|
@ -845,5 +845,44 @@ TEST_F(DeclarableOpsTests13, batch_to_space_nd_3) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, mergemax_1) {
|
||||
|
||||
NDArray x1('c', {5, 5}, nd4j::DataType::FLOAT32);
|
||||
NDArray x2('c', {5, 5}, nd4j::DataType::FLOAT32);
|
||||
NDArray x3('c', {5, 5}, nd4j::DataType::FLOAT32);
|
||||
NDArray e('c', {5, 5}, nd4j::DataType::FLOAT32);
|
||||
x1.assign(3);
|
||||
x2.assign(1);
|
||||
x3.assign(2);
|
||||
e.assign(3);
|
||||
|
||||
|
||||
nd4j::ops::mergemax op;
|
||||
auto result = op.execute({&x1, &x2, &x3}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
// z->printBuffer();
|
||||
|
||||
ASSERT_TRUE(e.isSameShape(z));
|
||||
ASSERT_TRUE(e.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, mergemax_2) {
|
||||
|
||||
NDArray x1('c', {1, 3}, {0., 1, 2}, nd4j::DataType::FLOAT32);
|
||||
NDArray x2('c', {1, 1}, {1.}, nd4j::DataType::FLOAT32);
|
||||
NDArray out('c', {1, 3}, {-1., -1, -1}, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::mergemax op;
|
||||
auto status = op.execute({&x1, &x2}, {&out}, {}, {}, {});
|
||||
|
||||
ASSERT_EQ(20, status);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -1548,8 +1548,8 @@ TEST_F(DeclarableOpsTests6, LogMatrixDeterminant_1) {
|
|||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
//z->printIndexedBuffer("Output ");
|
||||
//exp.printIndexedBuffer("Expected ");
|
||||
z->printIndexedBuffer("Log ABS Output ");
|
||||
exp.printIndexedBuffer("Log ABS Expected ");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
@ -1671,6 +1671,40 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests6, MatrixInverse_010) {
|
||||
|
||||
auto x = NDArrayFactory::create<float>('c', {1, 5, 5}, {
|
||||
1., 0., 0., 0., 0.,
|
||||
2., 1., 0., 0., 0.,
|
||||
30., 2., 1., 0., 0.,
|
||||
4., 3., 2., 1., 0.,
|
||||
5., 4., 3., 2., 1.,
|
||||
});
|
||||
|
||||
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {
|
||||
1.0, 0.0, 0.0, 0.0, 0.,
|
||||
-2.0, 1.0, 0., 0., 0.,
|
||||
-26.0, -2.0, 1, 0, 0.,
|
||||
54.0, 1.0, -2.0, 1, 0.,
|
||||
-27.0, 0.0, 1.0, -2.0, 1.
|
||||
});
|
||||
|
||||
nd4j::ops::matrix_inverse op;
|
||||
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
// z->printIndexedBuffer("010 Output ");
|
||||
// exp.printIndexedBuffer("010 Expected ");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests6, MatrixInverse_01) {
|
||||
|
||||
|
@ -1824,7 +1858,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_03) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {5, 5}, {
|
||||
auto x = NDArrayFactory::create<float>('c', {5, 5}, {
|
||||
4., 0., 0., 0., 0.,
|
||||
4., 2., 0., 0., 0.,
|
||||
30., 2., 1., 0., 0.,
|
||||
|
@ -1832,7 +1866,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
|
|||
15., 12., 9., 6., 3.,
|
||||
});
|
||||
|
||||
auto exp = NDArrayFactory::create<double>('c', {5, 5}, {
|
||||
auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
|
||||
0.25, 0.0, 0.0, 0.0, 0.0,
|
||||
-0.50, 0.5, 0.0, 0.0, 0.0,
|
||||
-6.50, -1.0, 1.0, 0.0, 0.0,
|
||||
|
@ -1841,13 +1875,13 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
|
|||
});
|
||||
|
||||
nd4j::ops::matrix_inverse op;
|
||||
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
||||
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
exp.printIndexedBuffer("Expected ");
|
||||
z->printIndexedBuffer("Output ");
|
||||
// exp.printIndexedBuffer("Expected ");
|
||||
// z->printIndexedBuffer("Output ");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
@ -1880,8 +1914,42 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
|
|||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
z->printIndexedBuffer("Output ");
|
||||
exp.printIndexedBuffer("Expected ");
|
||||
// z->printIndexedBuffer("Output ");
|
||||
// exp.printIndexedBuffer("Expected ");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests6, MatrixInverse_04) {
|
||||
|
||||
auto x = NDArrayFactory::create<float>('c', {5, 5}, {
|
||||
1., 2., 30., 4., 5.,
|
||||
0., 1., 2., 3., 4.,
|
||||
0., 0., 1., 2., 3.,
|
||||
0., 0., 0., 1., 2.,
|
||||
0., 0., 0., 0., 1.
|
||||
});
|
||||
|
||||
auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
|
||||
1.0, -2.0, -26.0, 54.0, -27.0,
|
||||
0.0, 1.0, -2.0, 1.0, 0.0,
|
||||
0.0, 0.0, 1.0, -2.0, 1.0,
|
||||
0.0, 0.0, 0.0, 1.0, -2.0,
|
||||
0.0, 0.0, 0.0, 0.0, 1.0
|
||||
});
|
||||
|
||||
nd4j::ops::matrix_inverse op;
|
||||
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
// z->printIndexedBuffer("Output ");
|
||||
// exp.printIndexedBuffer("Expected ");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
|
|
@ -9363,15 +9363,28 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
|
||||
// #include <NDArray.h>
|
||||
// #include <dll.h>
|
||||
|
||||
|
||||
@Namespace("nd4j") @NoOffset public static class OpArgsHolder extends Pointer {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public OpArgsHolder(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public OpArgsHolder(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public OpArgsHolder position(long position) {
|
||||
return (OpArgsHolder)super.position(position);
|
||||
}
|
||||
|
||||
|
||||
|
||||
// default constructor
|
||||
public OpArgsHolder() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
|
||||
// copy constructor
|
||||
public OpArgsHolder(@Const @ByRef OpArgsHolder other) { super((Pointer)null); allocate(other); }
|
||||
private native void allocate(@Const @ByRef OpArgsHolder other);
|
||||
|
||||
// constructor
|
||||
public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); }
|
||||
private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/);
|
||||
public OpArgsHolder(@Const @ByRef NDArrayVector inArrs) { super((Pointer)null); allocate(inArrs); }
|
||||
|
@ -9387,6 +9400,13 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); }
|
||||
private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/);
|
||||
|
||||
// move constructor
|
||||
|
||||
// assignment operator
|
||||
public native @ByRef @Name("operator =") OpArgsHolder put(@Const @ByRef OpArgsHolder other);
|
||||
|
||||
// move assignment operator
|
||||
|
||||
public native @Const @ByRef NDArrayVector getInArrs();
|
||||
|
||||
public native @StdVector DoublePointer getTArgs();
|
||||
|
@ -9406,8 +9426,8 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
public native int getNumBArgs();
|
||||
|
||||
public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs, @Cast("const bool") boolean isInPlace/*=false*/);
|
||||
public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs);
|
||||
|
||||
public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs);
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -9060,15 +9060,28 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
|
||||
// #include <NDArray.h>
|
||||
// #include <dll.h>
|
||||
|
||||
|
||||
@Namespace("nd4j") @NoOffset public static class OpArgsHolder extends Pointer {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public OpArgsHolder(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public OpArgsHolder(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public OpArgsHolder position(long position) {
|
||||
return (OpArgsHolder)super.position(position);
|
||||
}
|
||||
|
||||
|
||||
|
||||
// default constructor
|
||||
public OpArgsHolder() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
|
||||
// copy constructor
|
||||
public OpArgsHolder(@Const @ByRef OpArgsHolder other) { super((Pointer)null); allocate(other); }
|
||||
private native void allocate(@Const @ByRef OpArgsHolder other);
|
||||
|
||||
// constructor
|
||||
public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); }
|
||||
private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/);
|
||||
public OpArgsHolder(@Const @ByRef NDArrayVector inArrs) { super((Pointer)null); allocate(inArrs); }
|
||||
|
@ -9084,6 +9097,13 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); }
|
||||
private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector<double>()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector<Nd4jLong>()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/);
|
||||
|
||||
// move constructor
|
||||
|
||||
// assignment operator
|
||||
public native @ByRef @Name("operator =") OpArgsHolder put(@Const @ByRef OpArgsHolder other);
|
||||
|
||||
// move assignment operator
|
||||
|
||||
public native @Const @ByRef NDArrayVector getInArrs();
|
||||
|
||||
public native @StdVector DoublePointer getTArgs();
|
||||
|
@ -9103,8 +9123,8 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
public native int getNumBArgs();
|
||||
|
||||
public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs, @Cast("const bool") boolean isInPlace/*=false*/);
|
||||
public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs);
|
||||
|
||||
public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs);
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -45,6 +45,7 @@ import org.nd4j.linalg.indexing.NDArrayIndex;
|
|||
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
@ -548,6 +549,8 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
Nd4j.exec(op); //Execution is OK
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void testDepthwise(){
|
||||
INDArray input = Nd4j.create(DataType.DOUBLE, 1,3,8,8);
|
||||
|
@ -625,6 +628,49 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
System.out.println(out);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUpsampling2dBackprop(){
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
int c = 2;
|
||||
int[] sz = {2,2};
|
||||
long[] inSize = {1, c, 3, 3};
|
||||
INDArray eps = Nd4j.rand(DataType.FLOAT, 1, c, sz[0] * inSize[2], sz[1] * inSize[3]);
|
||||
|
||||
INDArray input = Nd4j.create(inSize); //Unused, not sure why this is even an arg...
|
||||
INDArray exp = Nd4j.create(DataType.FLOAT, inSize);
|
||||
|
||||
for( int ch=0; ch<c; ch++ ) {
|
||||
for( int h=0; h<eps.size(2); h++ ){
|
||||
for( int w=0; w<eps.size(3); w++ ){
|
||||
int[] from = new int[]{0, ch, h, w};
|
||||
int[] to = new int[]{0, ch, h/sz[0], w/sz[1]};
|
||||
float add = eps.getFloat(from);
|
||||
float current = exp.getFloat(to);
|
||||
exp.putScalar(to, current + add);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
System.out.println("Eps:");
|
||||
System.out.println(eps.shapeInfoToString());
|
||||
System.out.println(Arrays.toString(eps.data().asFloat()));
|
||||
|
||||
System.out.println("Expected:");
|
||||
System.out.println(exp.shapeInfoToString());
|
||||
System.out.println(Arrays.toString(exp.data().asFloat()));
|
||||
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("upsampling2d_bp")
|
||||
.addInputs(input, eps)
|
||||
.addOutputs(exp.ulike())
|
||||
.addIntegerArguments(1) //1 = NCHW
|
||||
.build();
|
||||
|
||||
Nd4j.exec(op);
|
||||
|
||||
INDArray act = op.getOutputArgument(0);
|
||||
assertEquals(exp, act);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIsMaxView(){
|
||||
|
|
Loading…
Reference in New Issue