diff --git a/libnd4j/include/helpers/AveragingArrayProxy.h b/libnd4j/include/helpers/AveragingArrayProxy.h deleted file mode 100644 index 58709e7df..000000000 --- a/libnd4j/include/helpers/AveragingArrayProxy.h +++ /dev/null @@ -1,59 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - - -#ifndef DEV_TESTS_AVERAGINGARRAYPROXY_H -#define DEV_TESTS_AVERAGINGARRAYPROXY_H - -#include "NDArray.h" -#include -#include -#include -#include - -namespace nd4j { - class ND4J_EXPORT AveragingArrayProxy { - protected: - NDArray *_original; - - std::map, NDArray*> _writeables; - std::map> _writeablesLinear; - std::vector _rows; - - std::vector _references; - - std::mutex _lock; - public: - explicit AveragingArrayProxy(NDArray *original); - ~AveragingArrayProxy(); - - NDArray* readable(int row, int key); - NDArray* writeable(int row, int key); - - bool isEmpty(); - - bool writeableExists(std::pair &key); - bool writeableExists(int row, int key); - - bool collapseWrites(); - }; -} - -#endif //DEV_TESTS_AVERAGINGARRAYPROXY_H diff --git a/libnd4j/include/helpers/OpArgsHolder.h b/libnd4j/include/helpers/OpArgsHolder.h index 6cf366f34..5d792105c 100644 --- a/libnd4j/include/helpers/OpArgsHolder.h +++ b/libnd4j/include/helpers/OpArgsHolder.h @@ -26,27 +26,42 @@ #include namespace nd4j { - + class ND4J_EXPORT OpArgsHolder { -private: +private: + std::vector _inArrs = std::vector(); - std::vector _tArgs = std::vector(); - std::vector _iArgs = std::vector(); - std::vector _bArgs = std::vector(); + std::vector _tArgs = std::vector(); + std::vector _iArgs = std::vector(); + std::vector _bArgs = std::vector(); + + std::vector _isArrAlloc = std::vector(); int _numInArrs = _inArrs.size(); int _numTArgs = _tArgs.size(); int _numIArgs = _iArgs.size(); int _numBArgs = _bArgs.size(); - std::vector _isArrAlloc = std::vector(); public: - OpArgsHolder() = delete; + // default constructor + OpArgsHolder(); - OpArgsHolder(const std::vector& inArrs, const std::vector& tArgs = std::vector(), const std::vector& iArgs = std::vector(), const std::vector& bArgs = std::vector()) - : _inArrs(inArrs), _tArgs(tArgs), _iArgs(iArgs), _bArgs(bArgs) { } + // copy constructor + OpArgsHolder(const OpArgsHolder& other); + + // constructor + OpArgsHolder(const std::vector& inArrs, const std::vector& tArgs = std::vector(), const std::vector& iArgs = std::vector(), const std::vector& bArgs = std::vector()); + + // move constructor + OpArgsHolder(OpArgsHolder&& other) noexcept; + + // assignment operator + OpArgsHolder& operator=(const OpArgsHolder& other); + + // move assignment operator + OpArgsHolder& operator=(OpArgsHolder&& other) noexcept; const std::vector& getInArrs() const {return _inArrs; } @@ -77,8 +92,8 @@ public: OpArgsHolder createArgsHolderForBP(const std::vector& inGradArrs, const bool isInPlace = false) const; - ~OpArgsHolder() noexcept; - + ~OpArgsHolder() noexcept; + }; diff --git a/libnd4j/include/helpers/cuda_off/MmulHelper.cu b/libnd4j/include/helpers/cuda_off/MmulHelper.cu index ac5eb4176..dda709545 100644 --- a/libnd4j/include/helpers/cuda_off/MmulHelper.cu +++ b/libnd4j/include/helpers/cuda_off/MmulHelper.cu @@ -229,7 +229,6 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc); } else if(ABC && aType == DataType::HALF) { - printf("!!!!!!!!\n"); float16 alphaH(alpha), betaH(beta); status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc); } diff --git a/libnd4j/include/helpers/impl/AveragingArrayProxy.cpp b/libnd4j/include/helpers/impl/AveragingArrayProxy.cpp deleted file mode 100644 index 9ebaeb443..000000000 --- a/libnd4j/include/helpers/impl/AveragingArrayProxy.cpp +++ /dev/null @@ -1,139 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include "../AveragingArrayProxy.h" - -namespace nd4j { - AveragingArrayProxy::AveragingArrayProxy(NDArray *original) { - _original = original; - } - - AveragingArrayProxy::~AveragingArrayProxy() { - for (auto v:_references) - delete v; - } - - bool AveragingArrayProxy::writeableExists(std::pair &key) { - _lock.lock(); - - auto r = _writeables.count(key) > 0; - - _lock.unlock(); - - return r; - } - - bool AveragingArrayProxy::writeableExists(int row, int key) { - std::pair k(row, key); - return writeableExists(k); - } - - NDArray* AveragingArrayProxy::readable(int row, int key) { - std::pair k(row, key); - - if (writeableExists(k)) { - _lock.lock(); - - auto r = _writeables[k]; - - _lock.unlock(); - - return r; - } else { - auto readable = (*_original)({row,row+1, 0,0}); - - _lock.lock(); - - _references.emplace_back(&readable); - - _lock.unlock(); - - // return readable; - } - } - - bool AveragingArrayProxy::isEmpty() { - return _original->isEmpty(); - } - - NDArray* AveragingArrayProxy::writeable(int row, int key) { - std::pair k(row, key); - - // if writeable exists - just return it - if (writeableExists(k)) { - _lock.lock(); - - auto r = _writeables[k]; - - _lock.unlock(); - - return r; - } else { - // if doesn't - let's create it - auto orig = (*_original)({row,row+1, 0,0}); - - // we don't want views here for obvious reasons - auto writeable = orig.dup('c'); - - _lock.lock(); - - _writeables[k] = writeable; - _references.emplace_back(writeable); - - // storing linear reference, for future averaging step - if (_writeablesLinear.count(row) == 0) { - std::vector vec; - _writeablesLinear[row] = vec; - } - - _writeablesLinear[row].emplace_back(writeable); - _rows.emplace_back(row); - - _lock.unlock(); - - return writeable; - } - } - - bool AveragingArrayProxy::collapseWrites() { - if (_writeables.empty()) - return false; - - for (int r = 0; r < _rows.size(); r++) { - auto row = _rows[r]; - auto list = _writeablesLinear.at(row); - - auto originalRow = (*_original)({row,row+1, 0,0}); - - if (list.size() == 1) { - originalRow.assign(list.at(0)); - } else { - originalRow.assign(0.0); - - for (int e = 0; e < list.size(); e++) - originalRow += *(list.at(e)); - - originalRow /= (int) list.size(); - } - } - - return true; - } -} \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/OpArgsHolder.cpp b/libnd4j/include/helpers/impl/OpArgsHolder.cpp index 984273bca..816253bc6 100644 --- a/libnd4j/include/helpers/impl/OpArgsHolder.cpp +++ b/libnd4j/include/helpers/impl/OpArgsHolder.cpp @@ -23,27 +23,122 @@ namespace nd4j { +//////////////////////////////////////////////////////////////////////// +// default constructor +OpArgsHolder::OpArgsHolder() { + + _inArrs = std::vector(); + _tArgs = std::vector(); + _iArgs = std::vector(); + _bArgs = std::vector(); + + _isArrAlloc = std::vector(); + + _numInArrs = 0; + _numTArgs = 0; + _numIArgs = 0; + _numBArgs = 0; +} + +//////////////////////////////////////////////////////////////////////// +// copy constructor +OpArgsHolder::OpArgsHolder(const OpArgsHolder& other) { + + throw std::runtime_error("OpArgsHolder::OpArgsHolder copy constructor: don't use me !"); +} + + +//////////////////////////////////////////////////////////////////////// +// constructor +OpArgsHolder::OpArgsHolder(const std::vector& inArrs, + const std::vector& tArgs, + const std::vector& iArgs, + const std::vector& bArgs) { + _inArrs = inArrs; + _tArgs = tArgs; + _iArgs = iArgs; + _bArgs = bArgs; + + _isArrAlloc = std::vector(); + + _numInArrs = _inArrs.size(); + _numTArgs = _tArgs.size(); + _numIArgs = _iArgs.size(); + _numBArgs = _bArgs.size(); +} + +//////////////////////////////////////////////////////////////////////// +// move constructor +OpArgsHolder::OpArgsHolder(OpArgsHolder&& other) noexcept: _inArrs(std::move(other._inArrs)), + _tArgs(std::move(other._tArgs)), + _iArgs(std::move(other._iArgs)), + _bArgs(std::move(other._bArgs)), + _isArrAlloc(std::move(other._isArrAlloc)) { + + other._isArrAlloc = std::vector(); + + _numInArrs = _inArrs.size(); + _numTArgs = _tArgs.size(); + _numIArgs = _iArgs.size(); + _numBArgs = _bArgs.size(); +} + +//////////////////////////////////////////////////////////////////////// +// assignment operator +OpArgsHolder& OpArgsHolder::operator=(const OpArgsHolder& other) { + + throw std::runtime_error("OpArgsHolder::OpArgsHolder assignment operator: don't use me !"); +} + + +//////////////////////////////////////////////////////////////////////// +// move assignment operator +OpArgsHolder& OpArgsHolder::operator=(OpArgsHolder&& other) noexcept { + + if (this == &other) + return *this; + + for (int i = 0; i < _isArrAlloc.size(); ++i) // delete arrays if necessary + if(_isArrAlloc[i]) + delete _inArrs[i]; + + _inArrs = std::move(other._inArrs); + _tArgs = std::move(other._tArgs); + _iArgs = std::move(other._iArgs); + _bArgs = std::move(other._bArgs); + _isArrAlloc = std::move(other._isArrAlloc); + + other._isArrAlloc = std::vector(); + + _numInArrs = _inArrs.size(); + _numTArgs = _tArgs.size(); + _numIArgs = _iArgs.size(); + _numBArgs = _bArgs.size(); + + return *this; +} + //////////////////////////////////////////////////////////////////////// OpArgsHolder OpArgsHolder::createArgsHolderForBP(const std::vector& inGradArrs, const bool isInPlace) const { - + const int numInGradArrs = inGradArrs.size(); OpArgsHolder result(std::vector(_numInArrs + numInGradArrs, nullptr), _tArgs, _iArgs); - + if(isInPlace) result._isArrAlloc = std::vector(_numInArrs + numInGradArrs, false); for (int i = 0; i < _numInArrs; ++i) { - - if(isInPlace) { + + if(isInPlace) { result._inArrs[i] = new NDArray(*_inArrs[i]); // make copy result._isArrAlloc[i] = true; } - else - result._inArrs[i] = _inArrs[i]; + else + result._inArrs[i] = _inArrs[i]; } - // input gradients + // input gradients for (int i = 0; i < numInGradArrs; ++i) result._inArrs[_numInArrs + i] = inGradArrs[i]; @@ -53,11 +148,10 @@ OpArgsHolder OpArgsHolder::createArgsHolderForBP(const std::vector& in //////////////////////////////////////////////////////////////////////// // default destructor OpArgsHolder::~OpArgsHolder() noexcept { - + for (int i = 0; i < _isArrAlloc.size(); ++i) if(_isArrAlloc[i]) delete _inArrs[i]; - } } diff --git a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp index 5eb278f68..5249758bf 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp @@ -44,7 +44,7 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) { auto input = INPUT_VARIABLE(i); auto currentRank = input->rankOf(); -// TODO: follow two lines are accordingly with current tf.concat spec. Commented for compatibility with legacy +// TODO: follow two lines are in accordance to current tf.concat spec. Commented for compatibility with legacy // REQUIRE_TRUE(currentRank > 0, 0, "Rank of input variable %i must be greater 0, but is %lld instead.", i, currentRank); // REQUIRE_TRUE(theFirstRank == currentRank, 0, "Number of dimensions in concat should be equals, but for %i input variable %lld != %lld appears.", i, currentRank, theFirstRank); if(!input->isEmpty()) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp index 033e0b5e5..dd5516461 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp @@ -1147,8 +1147,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d( // gradO has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) // gradI has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - gradI.nullify(); - const T* x = gradO.bufferAsT(); T* z = gradI.bufferAsT(); @@ -1182,8 +1180,10 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d( const auto zOffset = b*zStride0 + c*zStride1 + h*zStride2 + w*zStride3; - for(uint xh = h; xh < h + factorH; ++xh) - for(uint xw = w; xw < w + factorW; ++xw) + z[zOffset] = 0; + + for(uint xh = h * factorH; xh < h * factorH + factorH; ++xh) + for(uint xw = w * factorW; xw < w * factorW + factorW; ++xw) z[zOffset] += x[b*xStride0 + c*xStride1 + xh*xStride2 + xw*xStride3]; } } @@ -1198,8 +1198,6 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d( // input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) // output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) - gradI.nullify(); - const T* x = gradO.bufferAsT(); T* z = gradI.bufferAsT(); @@ -1238,9 +1236,11 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d( const auto zOffset = b*zStride0 + c*zStride1 + d*zStride2 + h*zStride3 + w*zStride4; - for(uint xd = d; xd < d + factorD; ++xd) - for(uint xh = h; xh < h + factorH; ++xh) - for(uint xw = w; xw < w + factorW; ++xw) + z[zOffset] = 0; + + for(uint xd = d * factorD; xd < d * factorD + factorD; ++xd) + for(uint xh = h * factorH; xh < h * factorH + factorH; ++xh) + for(uint xw = w * factorW; xw < w * factorW + factorW; ++xw) z[zOffset] += x[b*xStride0 + c*xStride1 + xd*xStride2 + xh*xStride3 + xw*xStride4]; } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp index 1fb1ef1df..ee9a78cee 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp @@ -26,6 +26,7 @@ namespace nd4j { namespace ops { namespace helpers { + nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext(); template static void swapRows_(NDArray* matrix, int theFirst, int theSecond) { @@ -114,7 +115,7 @@ namespace helpers { NDArray determinant = NDArrayFactory::create(1.f); NDArray compoundMatrix = *input; // copy - NDArray permutationMatrix(input, false, input->getContext()); // has same shape as input and contiguous strides + NDArray permutationMatrix(input, false, defaultContext); // has same shape as input and contiguous strides permutationMatrix.setIdentity(); T pivotValue; // = T(0.0); @@ -170,7 +171,7 @@ namespace helpers { Nd4jLong n = input->sizeAt(-1); Nd4jLong n2 = n * n; - auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), input->getContext()); //, block.getWorkspace()); + auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace()); for (int e = 0; e < output->lengthOf(); e++) { for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) @@ -184,6 +185,7 @@ namespace helpers { BUILD_SINGLE_TEMPLATE(template int determinant_, (NDArray* input, NDArray* output), FLOAT_TYPES); int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { + defaultContext = context; BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (input, output), FLOAT_TYPES); } @@ -193,7 +195,7 @@ template Nd4jLong n = input->sizeAt(-1); Nd4jLong n2 = n * n; - NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), input->getContext()); //, block.getWorkspace()); + NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace()); for (int e = 0; e < output->lengthOf(); e++) { for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) { matrix.p(row, input->e(k)); @@ -220,11 +222,11 @@ template auto totalCount = output->lengthOf() / n2; output->assign(0.f); // fill up output tensor with zeros - auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), input->getContext()); //, block.getWorkspace()); - auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), input->getContext()); //, block.getWorkspace()); - auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), input->getContext()); - auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), input->getContext()); - auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), input->getContext()); + auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), defaultContext); //, block.getWorkspace()); + auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), defaultContext); //, block.getWorkspace()); + auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), defaultContext); + auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), defaultContext); + auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), defaultContext); for (int e = 0; e < totalCount; e++) { if (e) @@ -266,6 +268,7 @@ template } int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { + defaultContext = context; BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (input, output), FLOAT_TYPES); } @@ -308,8 +311,8 @@ template if (!inplace) output->assign(0.f); // fill up output tensor with zeros only inplace=false - std::unique_ptr matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), input->getContext())); //, block.getWorkspace()); - std::unique_ptr lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), input->getContext())); + std::unique_ptr matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), defaultContext)); //, block.getWorkspace()); + std::unique_ptr lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), defaultContext)); for (int e = 0; e < totalCount; e++) { @@ -343,6 +346,7 @@ template } int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) { + defaultContext = context; BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/sg_cb.cpp b/libnd4j/include/ops/declarable/helpers/cpu/sg_cb.cpp index b96b367c1..b0fd449c7 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/sg_cb.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/sg_cb.cpp @@ -19,8 +19,6 @@ // #include -#include -#include #include #define HS_MAX_EXP 6.0f diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu index 991cdb660..e224329f0 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu @@ -1496,8 +1496,10 @@ __global__ static void upsampling2dBPCuda(const void* vx, const Nd4jLong* xShape const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank); - const Nd4jLong zCoord2 = coords[dimIH]; - const Nd4jLong zCoord3 = coords[dimIH + 1]; + z[zOffset] = 0; + + const Nd4jLong zCoord2 = coords[dimIH] * factorH; + const Nd4jLong zCoord3 = coords[dimIH + 1] * factorW; for(coords[dimIH] = zCoord2; coords[dimIH] < zCoord2 + factorH; ++coords[dimIH]) for(coords[dimIH + 1] = zCoord3; coords[dimIH + 1] < zCoord3 + factorW; ++coords[dimIH + 1]) @@ -1569,9 +1571,11 @@ __global__ static void upsampling3dBPCuda(const void* vx, const Nd4jLong* xShape const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank); - const Nd4jLong zCoord2 = coords[dimID]; - const Nd4jLong zCoord3 = coords[dimID + 1]; - const Nd4jLong zCoord4 = coords[dimID + 2]; + z[zOffset] = 0; + + const Nd4jLong zCoord2 = coords[dimID] * factorD; + const Nd4jLong zCoord3 = coords[dimID + 1] * factorH; + const Nd4jLong zCoord4 = coords[dimID + 2] * factorW; for(coords[dimID] = zCoord2; coords[dimID] < zCoord2 + factorD; ++coords[dimID]) for(coords[dimID + 1] = zCoord3; coords[dimID + 1] < zCoord3 + factorH; ++coords[dimID + 1]) diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu index 354f360c3..bf9c73e7c 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu @@ -31,6 +31,7 @@ namespace nd4j { namespace ops { namespace helpers { + nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext(); // template // static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) { @@ -55,10 +56,11 @@ namespace helpers { // void swapRows(NDArray* matrix, int theFirst, int theSecond) { // BUILD_SINGLE_SELECTOR(matrix->dataType(), swapRows_, (matrix, theFirst, theSecond), FLOAT_TYPES); // } - template - static __global__ void invertKernelLow(void* invertedBuf, Nd4jLong* invertedShape, void* inputBuf, Nd4jLong* inputShape, Nd4jLong n) { - T* inverted = reinterpret_cast(invertedBuf); - T* input = reinterpret_cast(inputBuf); + template + static __global__ void + invertKernelLow(void *invertedBuf, Nd4jLong *invertedShape, void *inputBuf, Nd4jLong *inputShape, Nd4jLong n) { + T *inverted = reinterpret_cast(invertedBuf); + T *input = reinterpret_cast(inputBuf); auto start = threadIdx.x + blockIdx.x * blockDim.x; auto step = blockDim.x * gridDim.x; @@ -76,10 +78,11 @@ namespace helpers { } } - template - static __global__ void upvertKernel(void* invertedBuf, Nd4jLong* invertedShape, void* inputBuf, Nd4jLong* inputShape, Nd4jLong n) { - T* inverted = reinterpret_cast(invertedBuf); - T* input = reinterpret_cast(inputBuf); + template + static __global__ void + upvertKernel(void *invertedBuf, Nd4jLong *invertedShape, void *inputBuf, Nd4jLong *inputShape, Nd4jLong n) { + T *inverted = reinterpret_cast(invertedBuf); + T *input = reinterpret_cast(inputBuf); auto start = threadIdx.x + blockIdx.x * blockDim.x; auto step = blockDim.x * gridDim.x; @@ -93,10 +96,26 @@ namespace helpers { } } - template - static __global__ void upvertKernelUp(void* invertedBuf, Nd4jLong* invertedShape, void* inputBuf, Nd4jLong* inputShape, Nd4jLong n) { - T* inverted = reinterpret_cast(invertedBuf); - T* input = reinterpret_cast(inputBuf); + template + static __global__ void + upvertKernelUp(void *invertedBuf, Nd4jLong *invertedShape, void *inputBuf, Nd4jLong *inputShape, Nd4jLong n) { + + __shared__ T* inverted; + __shared__ T* input; + __shared__ Nd4jLong* inputStride; + __shared__ Nd4jLong* invertedStride; + __shared__ Nd4jLong* invertedShapeOf; + __shared__ Nd4jLong* inputShapeOf; + if (threadIdx.x == 0) { + inverted = reinterpret_cast(invertedBuf); + input = reinterpret_cast(inputBuf); + inputStride = shape::stride(inputShape); + invertedStride = shape::stride(invertedShape); + invertedShapeOf = shape::shapeOf(invertedShape); + inputShapeOf = shape::shapeOf(inputShape); + + } + __syncthreads(); auto start = threadIdx.x + blockIdx.x * blockDim.x; auto step = blockDim.x * gridDim.x; @@ -105,20 +124,21 @@ namespace helpers { Nd4jLong pos[] = {i, i + 1}; //Nd4jLong posY[] = {i, i}; Nd4jLong posX[] = {i + 1, i + 1}; - auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2); + auto xIndex = shape::getOffset(0, inputShapeOf, shape::stride(inputShape), pos, 2); // auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posY, 2); // auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2); - auto iIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posX, 2); - auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2); - math::atomics::nd4j_atomicAdd(&inverted[zIndex], - input[xIndex] * inverted[iIndex]); // / input[yIndex]); + auto iIndex = shape::getOffset(0, invertedShapeOf, invertedStride, posX, 2); + auto zIndex = shape::getOffset(0, invertedShapeOf, invertedStride, pos, 2); + math::atomics::nd4j_atomicAdd(&inverted[zIndex], -input[xIndex] * inverted[iIndex]); // / input[yIndex]); //inputMatrix->t(i, i + 1) * invertedMatrix->t(i + 1, i + 1) / inputMatrix->t(i, i) } } - template - static __global__ void invertLowKernel(void* invertedBuf, Nd4jLong* invertedShape, void* inputBuf, Nd4jLong* inputShape, Nd4jLong n) { - T* inverted = reinterpret_cast(invertedBuf); - T* input = reinterpret_cast(inputBuf); + template + static __global__ void + invertLowKernel(void *invertedBuf, Nd4jLong *invertedShape, void *inputBuf, Nd4jLong *inputShape, Nd4jLong n) { + T *inverted = reinterpret_cast(invertedBuf); + T *input = reinterpret_cast(inputBuf); for (int i = blockIdx.x + 2; i < n; i += gridDim.x) { for (int j = i - 2; j >= 0; --j) @@ -129,76 +149,101 @@ namespace helpers { Nd4jLong posD[] = {i, i}; auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2); - auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2); + auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, + 2); auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2); - auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2); - math::atomics::nd4j_atomicAdd(&inverted[zIndex], - inverted[yIndex] * input[xIndex] / input[dIndex]); + auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, + 2); + math::atomics::nd4j_atomicAdd(&inverted[zIndex], -inverted[yIndex] * input[xIndex] / input[dIndex]); } } } - template - static __global__ void invertUpKernel(void* invertedBuf, Nd4jLong* invertedShape, void* inputBuf, Nd4jLong* inputShape, Nd4jLong n) { - T* inverted = reinterpret_cast(invertedBuf);; - T* input = reinterpret_cast(inputBuf); + template + static __global__ void + invertUpKernel(void *invertedBuf, Nd4jLong *invertedShape, void *inputBuf, Nd4jLong *inputShape, Nd4jLong n) { + __shared__ T* inverted; + __shared__ T* input; + __shared__ Nd4jLong* inputShapeOf; + __shared__ Nd4jLong* invertedShapeOf; + __shared__ Nd4jLong* invertedStrideOf; + __shared__ Nd4jLong* inputStrideOf; - for (int i = n - blockIdx.x - 2; i >= 0; i -= gridDim.x) { - for (int j = i + 2; j < n; j++) - for (int k = i + threadIdx.x; k < n; k+= blockDim.x) { + if (threadIdx.x == 0) { + inverted = reinterpret_cast(invertedBuf);; + input = reinterpret_cast(inputBuf); + inputShapeOf = shape::shapeOf(inputShape); + invertedShapeOf = shape::shapeOf(invertedShape); + inputStrideOf = shape::stride(inputShape); + invertedStrideOf = shape::stride(invertedShape); + } + __syncthreads(); + + for (int i = (int)n - blockIdx.x - 2; i >= 0; i -= gridDim.x) { + for (int j = i + 2; j < (int)n; j++) + for (int k = i + threadIdx.x; k < (int)n; k += blockDim.x) { Nd4jLong posZ[] = {i, j}; Nd4jLong posY[] = {k, j}; Nd4jLong posX[] = {i, k}; // Nd4jLong posD[] = {i, i}; - auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2); - auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2); - // auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2); - auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2); - math::atomics::nd4j_atomicAdd(&inverted[zIndex], - inverted[yIndex] * input[xIndex]);// / input[dIndex]); + auto xIndex = shape::getOffset(0, inputShapeOf, inputStrideOf, posX, 2); + auto yIndex = shape::getOffset(0, invertedShapeOf, invertedStrideOf, posY, 2); + // auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2); + auto zIndex = shape::getOffset(0, invertedShapeOf, invertedStrideOf, posZ, 2); + math::atomics::nd4j_atomicAdd(&inverted[zIndex], -inverted[yIndex] * input[xIndex]);// / input[dIndex]); +// printf("(%d, %d) inverted[%lld] = %lf (-inverted[%lld] * input[%lld]\n", blockIdx.x, threadIdx.x, zIndex, inverted[zIndex], yIndex, xIndex); } } } - template - static void invertLowerMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) { + template + static void invertLowerMatrix_(NDArray *inputMatrix, NDArray *invertedMatrix) { int n = inputMatrix->rows(); invertedMatrix->setIdentity(); if (inputMatrix->isIdentityMatrix()) return; - LaunchContext* context = inputMatrix->getContext(); - auto stream = context->getCudaStream(); + + auto stream = defaultContext->getCudaStream(); // invert main diagonal - upvertKernel<<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); + upvertKernel << < 1, n, 512, *stream >> > + (invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); // invert the second diagonal - invertKernelLow<<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); + invertKernelLow << < 1, n, 512, *stream >> > + (invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); // invertKernelLow<<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); - invertLowKernel<<>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); + invertLowKernel<<< n, n, 512, *stream >> > + (invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); } - void invertLowerMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { + void invertLowerMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) { NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE); NDArray::registerSpecialUse({invertedMatrix}, {inputMatrix}); } - template + template static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) { int n = inputMatrix->rows(); invertedMatrix->setIdentity(); - auto stream = inputMatrix->getContext()->getCudaStream(); + auto stream = defaultContext->getCudaStream(); if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I return; } //upvertKernel<<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); - upvertKernelUp<<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); - invertUpKernel<<>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); + upvertKernelUp<<<1, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), + inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); + invertedMatrix->tickWriteDevice(); + invertedMatrix->printIndexedBuffer("Step1 UP inversion"); + invertUpKernel<<>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), + inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); } - void invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { + void invertUpperMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) { NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); - BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE); + BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE); NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); } @@ -244,305 +289,273 @@ namespace helpers { // } // } - template - static __global__ void determinantKernel(T* compound, T* result, Nd4jLong len) { - F tempRes = (F)result[0]; - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - for (auto i = start; i < len; i += step) { - auto pos = i * len + i; //shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), di, 2); - math::atomics::nd4j_atomicMul(&tempRes, (F)compound[pos]); - } - __syncthreads(); - - if (threadIdx.x == 0) { - result[0] = (T)tempRes; - } - } - - template - static __global__ void determinantLogKernel(T* compound, T* result, Nd4jLong len) { - F tempRes = (F)result[0]; +// template + template + static __global__ void determinantKernel(T *compound, T *result, Nd4jLong len) { + //F tempRes = result[0]; auto start = blockIdx.x * blockDim.x + threadIdx.x; auto step = blockDim.x * gridDim.x; for (auto i = start; i < len; i += step) { auto pos = i * len + i; //shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), di, 2); - math::atomics::nd4j_atomicMul(&tempRes, (F)compound[pos]); + math::atomics::nd4j_atomicMul(&result[0], compound[pos]); } - __syncthreads(); + } + + template + static __global__ void determinantLogKernel(T *compound, T *result, Nd4jLong len) { +// F tempRes = (F)result[0]; + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + for (auto i = start; i < len; i += step) { + auto pos = i * len + i; //shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), di, 2); + math::atomics::nd4j_atomicAdd(result, math::nd4j_log(math::nd4j_abs(compound[pos]))); + } +// __syncthreads(); +// +// if (threadIdx.x == 0) { +// result[0] = (T)math::nd4j_log(math::nd4j_abs(tempRes)); +// } + } + + template + static __global__ void + fillMatrix(void *output, Nd4jLong *outShape, void *input, Nd4jLong *inputShape, Nd4jLong pos, Nd4jLong rowLen) { + __shared__ + F *matrix; + __shared__ + T *inputBuf; + __shared__ + Nd4jLong inputLen; + __shared__ + Nd4jLong n2; if (threadIdx.x == 0) { - result[0] = (T)math::nd4j_log(math::nd4j_abs(tempRes)); + matrix = reinterpret_cast(output); + inputBuf = reinterpret_cast(input); + inputLen = shape::length(inputShape); + n2 = rowLen * rowLen; + } + __syncthreads(); + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int k = pos + start, j = start; j < n2; k += step, j += step) { + auto xIndex = shape::getIndexOffset(k, inputShape, inputLen); + matrix[j] = (F) inputBuf[xIndex]; } } - template - static __global__ void fillMatrix(void* output, Nd4jLong* outShape, void* input, Nd4jLong* inputShape, Nd4jLong pos, Nd4jLong rowLen) { - __shared__ F* matrix; - __shared__ T* inputBuf; - __shared__ Nd4jLong inputLen; - __shared__ Nd4jLong n2; + template + static __global__ void + returnMatrix(void *output, Nd4jLong *outputShape, void *input, Nd4jLong *inputShape, Nd4jLong pos, + Nd4jLong rowLen) { + __shared__ T *matrix; + __shared__ T *outputBuf; + __shared__ Nd4jLong outputLen; + __shared__ Nd4jLong n2; - if (threadIdx.x == 0) { - matrix = reinterpret_cast(output); - inputBuf = reinterpret_cast(input); - inputLen = shape::length(inputShape); - n2 = rowLen * rowLen; - } - __syncthreads(); - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (int k = pos + start, j = start; j < n2; k += step, j += step) { - auto xIndex = shape::getIndexOffset(k, inputShape, inputLen); - matrix[j] = (F)inputBuf[xIndex]; - } - } - - template - static __global__ void returnMatrix(void* output, Nd4jLong* outputShape, void* input, Nd4jLong* inputShape, Nd4jLong pos, Nd4jLong rowLen) { - __shared__ F* matrix; - __shared__ T* outputBuf; - __shared__ Nd4jLong outputLen; - __shared__ Nd4jLong n2; - - if (threadIdx.x == 0) { - matrix = reinterpret_cast(input); - outputBuf = reinterpret_cast(output); - outputLen = shape::length(inputShape); - n2 = rowLen * rowLen; - } - __syncthreads(); - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (int k = pos + start, j = start; j < n2; k += step, j += step) { - auto zIndex = shape::getIndexOffset(k, outputShape, outputLen); - outputBuf[zIndex] = (T)matrix[j]; - } - } - - template - static __global__ void fillUpPermutation(void* output, Nd4jLong* shape, int* source, int rowNum) { - F* permutation = reinterpret_cast(output); - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - for (auto i = start; i < rowNum; i += step) { - int val = source[i] - 1; - Nd4jLong posF[] = {i, val}; - auto pos = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), posF, 2); - permutation[pos] = F(1.f); - } - } - - template - static void lup_(LaunchContext* context, NDArray* input, NDArray* compound, NDArray* permutation) { - auto stream = context->getCudaStream(); - auto n = input->rows(); - cusolverDnHandle_t cusolverH = nullptr; - cusolverStatus_t status = cusolverDnCreate(&cusolverH); - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("Cannot create cuSolver handle", status); - } - status = cusolverDnSetStream(cusolverH, *stream); - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("Cannot set up stream for cuda solver", status); - } - int lwork = 0; - int *d_info = nullptr; - - auto err = cudaMalloc((void **) &d_info, sizeof(int)); - if (err) { - throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver info buffer", err); - } - - DataType dtype = input->dataType(); - switch(dtype) { - - case DataType::DOUBLE: { - double *d_work = nullptr; - err = cudaMalloc((void **) &d_work, sizeof(float) * lwork); - if (err) { - throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver data buffer", err); - } - double *matrix = reinterpret_cast(input->specialBuffer()); - status = cusolverDnDgetrf_bufferSize( - cusolverH, - n, - n, - matrix, - n, - &lwork); - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("helpers::lup_: Cannot create cuSolver handle", status); - } - if (permutation == nullptr) - status = cusolverDnDgetrf( - cusolverH, - n, - n, - matrix, - n, - d_work, - nullptr, - d_info); - else { - NDArray permutVector('c', {n}, nd4j::DataType::INT32, context); - int *permutationBuf = reinterpret_cast(permutVector.specialBuffer()); - status = cusolverDnDgetrf( - cusolverH, - n, - n, - matrix, - n, - d_work, - permutationBuf, - d_info); - fillUpPermutation<<>>(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); - permutation->tickWriteDevice(); - } - err = cudaFree(d_work); - if (err) { - throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver data buffer", err); - } + if (threadIdx.x == 0) { + matrix = reinterpret_cast(input); + outputBuf = reinterpret_cast(output); + outputLen = shape::length(inputShape); + n2 = rowLen * rowLen; } - break; - case DataType::FLOAT32: { - float *matrix = reinterpret_cast(input->specialBuffer()); - float *d_work = nullptr; - err = cudaMalloc((void **) &d_work, sizeof(float) * lwork); - if (err) { - throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver data buffer", err); - } - - status = cusolverDnSgetrf_bufferSize( - cusolverH, - n, - n, - matrix, - n, - &lwork); - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("helpers::lup_: Cannot create cuSolver handle", status); - } - - if (permutation == nullptr) - status = cusolverDnSgetrf( - cusolverH, - n, - n, - matrix, - n, - d_work, - nullptr, - d_info); - else { - NDArray permutVector('c', {n}, nd4j::DataType::INT32, context); - int *permutationBuf = reinterpret_cast(permutVector.specialBuffer()); - status = cusolverDnSgetrf( - cusolverH, - n, - n, - matrix, - n, - d_work, - permutationBuf, - d_info); - fillUpPermutation<<>>(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); - permutation->tickWriteDevice(); - } - err = cudaFree(d_work); - if (err) { - throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver data buffer", err); - } + __syncthreads(); + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + for (int k = pos + start, j = start; j < n2; k += step, j += step) { + auto zIndex = shape::getIndexOffset(k, outputShape, outputLen); + outputBuf[zIndex] = (T) matrix[j]; } } - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("helpers::lup_: Cannot make LU decomposition", status); + + template + static __global__ void fillUpPermutation(void *output, Nd4jLong *shape, int *source, int rowNum) { + F *permutation = reinterpret_cast(output); + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + for (auto i = start; i < rowNum; i += step) { + int val = source[i] - 1; + Nd4jLong posF[] = {i, val}; + auto pos = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), posF, 2); + permutation[pos] = F(1.f); + } } - err = cudaFree(d_info); - if (err) { - throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver info buffer", err); - } - cusolverDnDestroy(cusolverH); + + template + static void lup_(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) { + auto stream = context->getCudaStream(); + auto n = input->rows(); + cusolverDnHandle_t cusolverH = nullptr; + cusolverStatus_t status = cusolverDnCreate(&cusolverH); + defaultContext = context; + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build("Cannot create cuSolver handle", status); + } + status = cusolverDnSetStream(cusolverH, *stream); + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build("Cannot set up stream for cuda solver", status); + } + int lwork = 0; + int *d_info = nullptr; + + auto err = cudaMalloc((void **) &d_info, sizeof(int)); + if (err) { + throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver info buffer", err); + } + + DataType dtype = input->dataType(); + switch (dtype) { + + case DataType::DOUBLE: { + double *d_work = nullptr; + err = cudaMalloc((void **) &d_work, sizeof(float) * lwork); + if (err) { + throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver data buffer", + err); + } + double *matrix = reinterpret_cast(input->specialBuffer()); + status = cusolverDnDgetrf_bufferSize( + cusolverH, + n, + n, + matrix, + n, + &lwork); + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build("helpers::lup_: Cannot create cuSolver handle", status); + } + if (permutation == nullptr) + status = cusolverDnDgetrf( + cusolverH, + n, + n, + matrix, + n, + d_work, + nullptr, + d_info); + else { + NDArray permutVector('c', {n}, nd4j::DataType::INT32, context); + int *permutationBuf = reinterpret_cast(permutVector.specialBuffer()); + status = cusolverDnDgetrf( + cusolverH, + n, + n, + matrix, + n, + d_work, + permutationBuf, + d_info); + fillUpPermutation << < n, n, 1024, *stream >> > + (permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); + permutation->tickWriteDevice(); + } + err = cudaFree(d_work); + if (err) { + throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver data buffer", + err); + } + } + break; + case DataType::FLOAT32: { + float *matrix = reinterpret_cast(input->specialBuffer()); + float *d_work = nullptr; + err = cudaMalloc((void **) &d_work, sizeof(float) * lwork); + if (err) { + throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver data buffer", + err); + } + + status = cusolverDnSgetrf_bufferSize( + cusolverH, + n, + n, + matrix, + n, + &lwork); + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build("helpers::lup_: Cannot create cuSolver handle", status); + } + + if (permutation == nullptr) + status = cusolverDnSgetrf( + cusolverH, + n, + n, + matrix, + n, + d_work, + nullptr, + d_info); + else { + NDArray permutVector('c', {n}, nd4j::DataType::INT32, context); + int *permutationBuf = reinterpret_cast(permutVector.specialBuffer()); + status = cusolverDnSgetrf( + cusolverH, + n, + n, + matrix, + n, + d_work, + permutationBuf, + d_info); + fillUpPermutation <<< n, n, 128, *stream >> > + (permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); + permutation->tickWriteDevice(); + } + err = cudaFree(d_work); + if (err) { + throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver data buffer", + err); + } + + } + } + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build("helpers::lup_: Cannot make LU decomposition", status); + } + err = cudaFree(d_info); + if (err) { + throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver info buffer", err); + } + cusolverDnDestroy(cusolverH); // NDArray::registerSpecialUse({input}, {input}); - input->tickWriteDevice(); - } - BUILD_SINGLE_TEMPLATE(template void lup_, (LaunchContext* context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_NATIVE); - - template - static int determinant_(nd4j::LaunchContext* context, NDArray* input, NDArray* output) { - Nd4jLong n = input->sizeAt(-1); - Nd4jLong n2 = n * n; - std::vector dims(); - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); - //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1}); - DataType dtype = input->dataType(); - if (dtype != DataType::DOUBLE) - dtype = DataType::FLOAT32; - - auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, input->getContext()); //, block.getWorkspace()); - auto det = NDArrayFactory::create(1); - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input}); - dim3 launchDims(256, 256, 1024); - output->assign(1.f); - for (int e = 0; e < output->lengthOf(); e++) { - Nd4jLong pos = e * n2; -// if (matrix.dataType() == input->dataType()) - fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); -// else -// fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); - -// if (matrix.dataType() == input->dataType()) - lup_(context, &matrix, nullptr, nullptr); -// else -// lup_(context, &matrix, nullptr, nullptr); - auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf()); - auto inputBuf = reinterpret_cast(matrix.specialBuffer()); - auto outputBuf = reinterpret_cast(output->specialBuffer()) + offset; -// if (matrix.dataType() == input->dataType()) - determinantKernel<<>> (inputBuf, outputBuf, n); -// else -// determinantKernel<<>> (inputBuf, outputBuf, n); + input->tickWriteDevice(); } - NDArray::registerSpecialUse({output}, {input}); - return Status::OK(); - } + BUILD_SINGLE_TEMPLATE(template void lup_, + (LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), + FLOAT_NATIVE); - int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input}); - BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE); - NDArray::registerSpecialUse({output}, {input}); - } - - template - int logAbsDeterminant_(LaunchContext* context, NDArray* input, NDArray* output) { - - Nd4jLong n = input->sizeAt(-1); - Nd4jLong n2 = n * n; - std::vector dims(); - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); - //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1}); - DataType dtype = input->dataType(); - if (dtype != DataType::DOUBLE) - dtype = DataType::FLOAT32; - - auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, input->getContext()); //, block.getWorkspace()); - auto det = NDArrayFactory::create(1); - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input}); - dim3 launchDims(256, 256, 1024); - output->assign(1.f); - for (int e = 0; e < output->lengthOf(); e++) { - Nd4jLong pos = e * n2; + template + static int determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { + Nd4jLong n = input->sizeAt(-1); + Nd4jLong n2 = n * n; + std::vector dims(); + auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), + {input->rankOf() - 2, input->rankOf() - 1}); + //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1}); +// DataType dtype = input->dataType(); +// if (dtype != DataType::DOUBLE) +// dtype = DataType::FLOAT32; + defaultContext = context; + auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT(), + defaultContext); //, block.getWorkspace()); + auto det = NDArrayFactory::create(1); + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input}); + dim3 launchDims(256, 256, 1024); + output->assign(1.f); + for (int e = 0; e < output->lengthOf(); e++) { + Nd4jLong pos = e * n2; // if (matrix.dataType() == input->dataType()) - fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); + fillMatrix << < launchDims.x, launchDims.y, launchDims.z, *stream >> > + (matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); // else // fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); @@ -550,292 +563,406 @@ namespace helpers { lup_(context, &matrix, nullptr, nullptr); // else // lup_(context, &matrix, nullptr, nullptr); - auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf()); - auto inputBuf = reinterpret_cast(matrix.specialBuffer()); - auto outputBuf = reinterpret_cast(output->specialBuffer()) + offset; + auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf()); + auto inputBuf = reinterpret_cast(matrix.specialBuffer()); + auto outputBuf = reinterpret_cast(output->specialBuffer()) + offset; // if (matrix.dataType() == input->dataType()) - determinantLogKernel<<>> (inputBuf, outputBuf, n); + determinantKernel << < launchDims.x, launchDims.y, launchDims.z, *stream >> > + (inputBuf, outputBuf, n); +// else +// determinantKernel<<>> (inputBuf, outputBuf, n); + } + NDArray::registerSpecialUse({output}, {input}); + + return Status::OK(); + } + + int determinant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { + defaultContext = context; + NDArray::prepareSpecialUse({output}, {input}); + BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE); + NDArray::registerSpecialUse({output}, {input}); + } + + template + int logAbsDeterminant_(LaunchContext *context, NDArray *input, NDArray *output) { + defaultContext = context; + Nd4jLong n = input->sizeAt(-1); + Nd4jLong n2 = n * n; + std::vector dims(); + auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), + {input->rankOf() - 2, input->rankOf() - 1}); + //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1}); + DataType dtype = input->dataType(); + if (dtype != DataType::DOUBLE) + dtype = DataType::FLOAT32; + + auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, + defaultContext); //, block.getWorkspace()); + auto det = NDArrayFactory::create(1); + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input}); + dim3 launchDims(256, 256, 1024); + output->assign(0.f); + for (int e = 0; e < output->lengthOf(); e++) { + Nd4jLong pos = e * n2; +// if (matrix.dataType() == input->dataType()) + fillMatrix << < launchDims.x, launchDims.y, launchDims.z, *stream >> > + (matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); +// else +// fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); + +// if (matrix.dataType() == input->dataType()) + lup_(context, &matrix, nullptr, nullptr); +// else +// lup_(context, &matrix, nullptr, nullptr); + auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf()); + auto inputBuf = reinterpret_cast(matrix.specialBuffer()); + auto outputBuf = reinterpret_cast(output->specialBuffer()) + offset; +// if (matrix.dataType() == input->dataType()) + determinantLogKernel << < launchDims.x, launchDims.y, launchDims.z, *stream >> > + (inputBuf, outputBuf, n); // else // determinantLogKernel<<>> (inputBuf, outputBuf, n); - } - NDArray::registerSpecialUse({output}, {input}); - - return Status::OK(); - - return ND4J_STATUS_OK; - } - - int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input}); - BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE); - NDArray::registerSpecialUse({output}, {input}); - } - - template - static __global__ void fillLowerUpperKernel(void* lowerBuf, Nd4jLong* lowerShape, void* upperBuf, Nd4jLong* upperShape, void* matrixBuf, Nd4jLong* matrixShape, Nd4jLong n) { - - __shared__ Nd4jLong* xShapeOf; - __shared__ Nd4jLong* yShapeOf; - __shared__ Nd4jLong* zShapeOf; - __shared__ Nd4jLong* xStrideOf; - __shared__ Nd4jLong* yStrideOf; - __shared__ Nd4jLong* zStrideOf; - __shared__ T* lowerMatrix; - __shared__ T* upperMatrix; - __shared__ T* matrix; - - if (threadIdx.x == 0) { - xShapeOf = shape::shapeOf(lowerShape); - xStrideOf = shape::stride(lowerShape); - - yShapeOf = shape::shapeOf(upperShape); - yStrideOf = shape::stride(upperShape); - - zShapeOf = shape::shapeOf(matrixShape); - zStrideOf = shape::stride(matrixShape); - lowerMatrix = reinterpret_cast(lowerBuf); - upperMatrix = reinterpret_cast(upperBuf); - matrix = reinterpret_cast(matrixBuf); - } - __syncthreads(); - - for (int k = blockIdx.x; k < n; k += gridDim.x) { // and then put all values under main diagonal on to it - for (int j = threadIdx.x; j < n; j += blockDim.x) { - Nd4jLong posX[] = {k, j}; - Nd4jLong posD[] = {j, j}; - auto xPos = shape::getOffset(0, xShapeOf, xStrideOf, posX, 2); - auto yPos = shape::getOffset(0, yShapeOf, yStrideOf, posX, 2); - auto iPos = shape::getOffset(0, zShapeOf, zStrideOf, posX, 2); - auto dPos = shape::getOffset(0, zShapeOf, zStrideOf, posD, 2); - if (k >= j) - lowerMatrix[xPos] = matrix[iPos];//(k, j); - else - upperMatrix[yPos] = matrix[iPos]; //k, j); } + NDArray::registerSpecialUse({output}, {input}); + + return Status::OK(); + + return ND4J_STATUS_OK; } - } - template - static int inverse_(nd4j::LaunchContext* context, NDArray* input, NDArray* output) { - auto n = input->sizeAt(-1); - auto n2 = n * n; - auto dtype = input->dataType(); - if (dtype != DataType::DOUBLE) - dtype = DataType::FLOAT32; - NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, input->getContext()); - NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, input->getContext()); - NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, input->getContext()); - NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, input->getContext()); - NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, input->getContext()); - auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); - auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {output->rankOf() - 2, output->rankOf() - 1}); - auto stream = context->getCudaStream(); - - for (auto i = 0LL; i < packX.numberOfTads(); i++) { - fillMatrix<<<1, n2, 1024, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n); - matrix.tickWriteDevice(); - compound.assign(matrix); - lup_(context, &compound, nullptr, nullptr); - fillLowerUpperKernel<<>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n); - matrix.assign(0); - invertUpperMatrix(&upper, &matrix); // U^{-1} - compound.assign(0); - invertLowerMatrix(&lower, &compound); // L{-1} - - nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0); - returnMatrix<<<1, n2, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n); + int logAbsDeterminant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { + defaultContext = context; + NDArray::prepareSpecialUse({output}, {input}); + BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE); + NDArray::registerSpecialUse({output}, {input}); } - return Status::OK(); - } - int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input}); - BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE); - NDArray::registerSpecialUse({output}, {input}); - } + template + static __global__ void + fillLowerUpperKernel(void *lowerBuf, Nd4jLong *lowerShape, void *upperBuf, Nd4jLong *upperShape, + void *matrixBuf, Nd4jLong *matrixShape, Nd4jLong n) { - bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) { - return true; - } + __shared__ + Nd4jLong *xShapeOf; + __shared__ + Nd4jLong *yShapeOf; + __shared__ + Nd4jLong *zShapeOf; + __shared__ + Nd4jLong *xStrideOf; + __shared__ + Nd4jLong *yStrideOf; + __shared__ + Nd4jLong *zStrideOf; + __shared__ + T *lowerMatrix; + __shared__ + T *upperMatrix; + __shared__ + T *matrix; - template - __global__ void fillBatchKernel(F** dArrayBatch, F* buf, Nd4jLong* offsets, Nd4jLong batchSize) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; + if (threadIdx.x == 0) { + xShapeOf = shape::shapeOf(lowerShape); + xStrideOf = shape::stride(lowerShape); - for (auto i = start; i < batchSize; i += step) { - dArrayBatch[i] = buf + offsets[i]; - } - } + yShapeOf = shape::shapeOf(upperShape); + yStrideOf = shape::stride(upperShape); - template - __global__ void adjustResultsKernel(F* dArray, Nd4jLong* shape, Nd4jLong* offsets, Nd4jLong batchSize, Nd4jLong n) { - //auto i = blockIdx.x * blockDim.x + threadIdx.x; - Nd4jLong* shapeOf = shape::shapeOf(shape); - Nd4jLong* strideOf = shape::stride(shape); + zShapeOf = shape::shapeOf(matrixShape); + zStrideOf = shape::stride(matrixShape); + lowerMatrix = reinterpret_cast(lowerBuf); + upperMatrix = reinterpret_cast(upperBuf); + matrix = reinterpret_cast(matrixBuf); + } + __syncthreads(); - for (auto i = blockIdx.x; i < batchSize; i+= gridDim.x) { - auto current = dArray + offsets[i]; - for (auto r = threadIdx.x; r < n; r += blockDim.x) { - for (auto c = r + 1; c < n; c++) { - Nd4jLong posRC[] = {r, c}; - auto pos = r * n + c; //shape::getOffset(0, shapeOf, strideOf, posRC, 2); - current[pos] = 0.; + for (int k = blockIdx.x; k < n; k += gridDim.x) { // and then put all values under main diagonal on to it + for (int j = threadIdx.x; j < n; j += blockDim.x) { + Nd4jLong posX[] = {k, j}; + Nd4jLong posD[] = {j, j}; + auto xPos = shape::getOffset(0, xShapeOf, xStrideOf, posX, 2); + auto yPos = shape::getOffset(0, yShapeOf, yStrideOf, posX, 2); + auto iPos = shape::getOffset(0, zShapeOf, zStrideOf, posX, 2); + auto dPos = shape::getOffset(0, zShapeOf, zStrideOf, posD, 2); + if (k >= j) + lowerMatrix[xPos] = matrix[iPos];//(k, j); + else + upperMatrix[yPos] = matrix[iPos]; //k, j); } } } - } - template - int cholesky__(LaunchContext* context, NDArray* input, NDArray* output, bool inplace) { - if (!inplace) - output->assign(input); - std::unique_ptr tempOutput(output->dup()); - cusolverDnHandle_t handle = nullptr; - auto n = input->sizeAt(-1); - auto n2 = n * n; - NDArray::prepareSpecialUse({output}, {input}); - auto status = cusolverDnCreate(&handle); - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("helpers::cholesky_: Cannot create solver handle", status); - } - F** dArrayBatch = nullptr; - auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput->getShapeInfo(), {tempOutput->rankOf() - 2, tempOutput->rankOf() - 1}); - const Nd4jLong batchSize = packX.numberOfTads(); - int* dInfoArray = nullptr; - auto err = cudaMalloc((void**)&dArrayBatch, sizeof(F*) * batchSize); - if (err) { - throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver batch data buffer", err); - } - err = cudaMalloc ((void**)&dInfoArray, sizeof(int) * batchSize); - if (err) { - throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver errors buffer", err); - } - auto stream = context->getCudaStream(); - fillBatchKernel<<<1, batchSize, 128, *stream>>>(dArrayBatch, reinterpret_cast(tempOutput->specialBuffer()), packX.specialOffsets(), batchSize); + template + static int inverse_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { + defaultContext = context; + auto n = input->sizeAt(-1); + auto n2 = n * n; + auto dtype = DataTypeUtils::fromT(); //input->dataType(); +// if (dtype != DataType::DOUBLE) +// dtype = DataType::FLOAT32; + NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); + NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); + NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); + NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); + NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), + {input->rankOf() - 2, + input->rankOf() - 1}); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), + {output->rankOf() - 2, + output->rankOf() - 1}); + auto stream = context->getCudaStream(); - status = cusolverDnSetStream(handle, *stream); - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("helpers::cholesky_: Cannot set stream to solver handle", status); - } - const cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER; - if (input->dataType() == DataType::DOUBLE) - status = cusolverDnDpotrfBatched( - handle, - uplo, - n, - (double**)dArrayBatch, - n, - dInfoArray, - batchSize); - else - status = cusolverDnSpotrfBatched( - handle, - uplo, - n, - (float**)dArrayBatch, - n, - dInfoArray, - batchSize); - - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("helpers::cholesky_: Cholesky factorization failed for batch", status); - } - adjustResultsKernel<<>>(reinterpret_cast(tempOutput->specialBuffer()), packX.specialShapeInfo(), packX.specialOffsets(), batchSize, n); - - err = cudaFree(dArrayBatch); - if (err) { - throw cuda_exception::build("helpers::cholesky_: Cannot deallocate memory for solver batch data buffer", err); - } - err = cudaFree(dInfoArray); - if (err) { - throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver errors buffer", err); + for (auto i = 0LL; i < packX.numberOfTads(); i++) { + fillMatrix << < 1, n2, 1024, *stream >> > + (matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), + i * n2, n); + matrix.tickWriteDevice(); + compound.assign(matrix); + lup_(context, &compound, nullptr, nullptr); + fillLowerUpperKernel << < n, n, 1024, *stream >> > + (lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n); + matrix.assign(0); + invertUpperMatrix(&upper, &matrix); // U^{-1} + matrix.tickWriteDevice(); +// matrix.printIndexedBuffer("Upper Inverted"); + compound.assign(0); + invertLowerMatrix(&lower, &compound); // L{-1} + compound.tickWriteDevice(); +// compound.printIndexedBuffer("Lower Inverted"); +// matrix.tickWriteDevice(); +// compound.tickWriteDevice(); + nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0); + upper.tickWriteDevice(); +// upper.printIndexedBuffer("Full inverted"); + returnMatrix << < 1, n2, 1024, *stream >> > + (output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), + i * n2, n); + } + return Status::OK(); } - if(!inplace) - output->assign(tempOutput.get()); - else - input->assign(tempOutput.get()); - - NDArray::registerSpecialUse({output}, {input}); - return Status::OK(); - } - -// template - int cholesky_(LaunchContext* context, NDArray* input, NDArray* output, bool inplace) { - NDArray::prepareSpecialUse({output}, {input}); - if (input->dataType() == DataType::DOUBLE) - cholesky__(context, input, output, inplace); - else if (input->dataType() == DataType::FLOAT32) - cholesky__(context, input, output, inplace); - else { - std::unique_ptr tempOutput(NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, input->getContext())); - tempOutput->assign(input); - cholesky__(context, tempOutput.get(), tempOutput.get(), true); - output->assign(tempOutput.get()); + int inverse(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { + defaultContext = context; + NDArray::prepareSpecialUse({output}, {input}); + BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE); + NDArray::registerSpecialUse({output}, {input}); } - NDArray::registerSpecialUse({output}, {input}); - return Status::OK(); - } - int cholesky(nd4j::LaunchContext* context, NDArray* input, NDArray* output, bool inplace) { -// BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES); - return cholesky_(context, input, output, inplace); - } -// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES); - BUILD_SINGLE_TEMPLATE(template int inverse_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_NATIVE); - - __global__ void logDetKernel(double* inputBuf, Nd4jLong* inputShape, Nd4jLong batchNum, Nd4jLong* tadShape, Nd4jLong* tadOffsets, double* outputBuf, Nd4jLong* outputShape) { - - __shared__ int n; - if (threadIdx.x == 0) { - n = shape::sizeAt(inputShape, -1); // * shape::sizeAt(inputShape, -1); + bool checkCholeskyInput(nd4j::LaunchContext *context, NDArray const *input) { + return true; } - __syncthreads(); - double* output = outputBuf; - double* input = inputBuf; + template + __global__ void fillBatchKernel(F **dArrayBatch, F *buf, Nd4jLong *offsets, Nd4jLong batchSize) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; - Nd4jLong* shapeOf = shape::shapeOf(tadShape); - Nd4jLong* strideOf = shape::stride(tadShape); - - for (auto i = blockIdx.x; i < batchNum; i += gridDim.x) { - double* current = input + tadOffsets[i]; - - auto zIndex = shape::getIndexOffset(i, outputShape, batchNum); - for (auto e = threadIdx.x; e < n; e += blockDim.x) { - Nd4jLong diag[] = {e, e}; - auto xIndex = shape::getOffset(0, shapeOf, strideOf, diag, 2); - math::atomics::nd4j_atomicAdd(&output[zIndex], math::nd4j_log(current[xIndex] * current[xIndex])); + for (auto i = start; i < batchSize; i += step) { + dArrayBatch[i] = buf + offsets[i]; } } - } - int logdetFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input}); - auto n2 = input->sizeAt(-1) * input->sizeAt(-2); - auto stream = context->getCudaStream(); - std::unique_ptr tempOutput(input->dup()); + template + __global__ void + adjustResultsKernel(F *dArray, Nd4jLong *shape, Nd4jLong *offsets, Nd4jLong batchSize, Nd4jLong n) { + //auto i = blockIdx.x * blockDim.x + threadIdx.x; + Nd4jLong *shapeOf = shape::shapeOf(shape); + Nd4jLong *strideOf = shape::stride(shape); + + for (auto i = blockIdx.x; i < batchSize; i += gridDim.x) { + auto current = dArray + offsets[i]; + for (auto r = threadIdx.x; r < n; r += blockDim.x) { + for (auto c = r + 1; c < n; c++) { + Nd4jLong posRC[] = {r, c}; + auto pos = r * n + c; //shape::getOffset(0, shapeOf, strideOf, posRC, 2); + current[pos] = 0.; + } + } + } + } + + template + int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { + if (!inplace) + output->assign(input); + defaultContext = context; + std::unique_ptr tempOutput(output->dup()); + cusolverDnHandle_t handle = nullptr; + auto n = input->sizeAt(-1); + auto n2 = n * n; + NDArray::prepareSpecialUse({output}, {input}); + auto status = cusolverDnCreate(&handle); + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build("helpers::cholesky_: Cannot create solver handle", status); + } + F **dArrayBatch = nullptr; + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput->getShapeInfo(), + {tempOutput->rankOf() - 2, + tempOutput->rankOf() - 1}); + const Nd4jLong batchSize = packX.numberOfTads(); + int *dInfoArray = nullptr; + auto err = cudaMalloc((void **) &dArrayBatch, sizeof(F *) * batchSize); + if (err) { + throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver batch data buffer", + err); + } + err = cudaMalloc((void **) &dInfoArray, sizeof(int) * batchSize); + if (err) { + throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver errors buffer", err); + } + auto stream = context->getCudaStream(); + fillBatchKernel << < 1, batchSize, 128, *stream >> > + (dArrayBatch, reinterpret_cast(tempOutput->specialBuffer()), packX.specialOffsets(), batchSize); + + status = cusolverDnSetStream(handle, *stream); + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build("helpers::cholesky_: Cannot set stream to solver handle", status); + } + const cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER; + if (input->dataType() == DataType::DOUBLE) + status = cusolverDnDpotrfBatched( + handle, + uplo, + n, + (double **) dArrayBatch, + n, + dInfoArray, + batchSize); + else + status = cusolverDnSpotrfBatched( + handle, + uplo, + n, + (float **) dArrayBatch, + n, + dInfoArray, + batchSize); + + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build("helpers::cholesky_: Cholesky factorization failed for batch", status); + } + adjustResultsKernel << < batchSize, n2, 128, *stream >> > + (reinterpret_cast(tempOutput->specialBuffer()), packX.specialShapeInfo(), packX.specialOffsets(), batchSize, n); + + err = cudaFree(dArrayBatch); + if (err) { + throw cuda_exception::build("helpers::cholesky_: Cannot deallocate memory for solver batch data buffer", + err); + } + err = cudaFree(dInfoArray); + if (err) { + throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver errors buffer", err); + } + + if (!inplace) + output->assign(tempOutput.get()); + else + input->assign(tempOutput.get()); + + NDArray::registerSpecialUse({output}, {input}); + return Status::OK(); + } + +// template + int cholesky_(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { + defaultContext = context; + NDArray::prepareSpecialUse({output}, {input}); + if (input->dataType() == DataType::DOUBLE) + cholesky__(context, input, output, inplace); + else if (input->dataType() == DataType::FLOAT32) + cholesky__(context, input, output, inplace); + else { + std::unique_ptr tempOutput( + NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, + defaultContext)); + tempOutput->assign(input); + cholesky__(context, tempOutput.get(), tempOutput.get(), true); + output->assign(tempOutput.get()); + } + NDArray::registerSpecialUse({output}, {input}); + return Status::OK(); + } + + int cholesky(nd4j::LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { +// BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES); + defaultContext = context; + return cholesky_(context, input, output, inplace); + } +// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES); + BUILD_SINGLE_TEMPLATE(template int inverse_, (nd4j::LaunchContext * context, NDArray * input, NDArray * output), + FLOAT_NATIVE); + + template + __global__ void + logDetKernel(T *inputBuf, Nd4jLong *inputShape, Nd4jLong batchNum, Nd4jLong *tadShape, Nd4jLong *tadOffsets, + T *outputBuf, Nd4jLong *outputShape) { + + __shared__ int n; + if (threadIdx.x == 0) { + n = shape::sizeAt(inputShape, -1); // * shape::sizeAt(inputShape, -1); + } + __syncthreads(); + + T *output = outputBuf; + T *input = inputBuf; + + Nd4jLong *shapeOf = shape::shapeOf(tadShape); + Nd4jLong *strideOf = shape::stride(tadShape); + + for (auto i = blockIdx.x; i < batchNum; i += gridDim.x) { + T *current = input + tadOffsets[i]; + + auto zIndex = shape::getIndexOffset(i, outputShape, batchNum); + for (auto e = threadIdx.x; e < n; e += blockDim.x) { + Nd4jLong diag[] = {e, e}; + auto xIndex = shape::getOffset(0, shapeOf, strideOf, diag, 2); + math::atomics::nd4j_atomicAdd(&output[zIndex], + math::nd4j_log(current[xIndex] * current[xIndex])); + } + } + } + + template + int logdetFunctor_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { + defaultContext = context; + NDArray::prepareSpecialUse({output}, {input}); + auto n2 = input->sizeAt(-1) * input->sizeAt(-2); + auto stream = context->getCudaStream(); + std::unique_ptr tempOutput(input->dup()); // auto inputs = tempOutput->allTensorsAlongDimension({input->rankOf() - 2, input->rankOf() - 1}); // for (Nd4jLong e = 0; e < packX.numberOfTads(); e++) { // auto subArray = inputs->at(e); // cholesky(context, subArray, subArray, true); // } // delete inputs; - cholesky(context, input, tempOutput.get(), false); - tempOutput->syncToHost(); - tempOutput->printIndexedBuffer("Cholesky res!!!"); - auto outputBuf = reinterpret_cast(output->specialBuffer()); // + e * n2; // + e * n2; - auto inputBuf = reinterpret_cast(tempOutput->specialBuffer()); - output->assign(0); - output->syncToDevice(); - auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); - logDetKernel<<>>(inputBuf, tempOutput->specialShapeInfo(), packX.numberOfTads(), packX.specialShapeInfo(), packX.specialOffsets(), outputBuf, output->specialShapeInfo()); + cholesky(context, input, tempOutput.get(), false); + tempOutput->syncToHost(); + tempOutput->printIndexedBuffer("Cholesky res!!!"); + auto outputBuf = reinterpret_cast(output->specialBuffer()); // + e * n2; // + e * n2; + auto inputBuf = reinterpret_cast(tempOutput->specialBuffer()); + output->assign(0); + output->syncToDevice(); + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput->getShapeInfo(), + {input->rankOf() - 2, + input->rankOf() - 1}); + logDetKernel << < packX.numberOfTads(), n2, 128, *stream >> > + (inputBuf, tempOutput->specialShapeInfo(), packX.numberOfTads(), packX.specialShapeInfo(), packX.specialOffsets(), outputBuf, output->specialShapeInfo()); // } - NDArray::registerSpecialUse({output}, {input}); - //delete tempOutput; - return Status::OK(); + NDArray::registerSpecialUse({output}, {input}); + //delete tempOutput; + return Status::OK(); + } + + int logdetFunctor(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { + defaultContext = context; + BUILD_SINGLE_SELECTOR(output->dataType(), logdetFunctor_, (context, input, output), FLOAT_NATIVE); + } + + BUILD_SINGLE_TEMPLATE(template int logdetFunctor_, + (nd4j::LaunchContext * context, NDArray * input, NDArray * output), FLOAT_NATIVE); } } } -} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/zeta.cu b/libnd4j/include/ops/declarable/helpers/cuda/zeta.cu index 94464bbbc..b131ff83f 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/zeta.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/zeta.cu @@ -33,19 +33,19 @@ __global__ static void zetaCuda(const void *vx, const Nd4jLong *xShapeInfo, const auto x = reinterpret_cast(vx); const auto q = reinterpret_cast(vq); - auto z = reinterpret_cast(vz); + auto z = reinterpret_cast(vz); __shared__ Nd4jLong len; - - if (threadIdx.x == 0) - len = shape::length(xShapeInfo); + + if (threadIdx.x == 0) + len = shape::length(xShapeInfo); __syncthreads(); const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto totalThreads = gridDim.x * blockDim.x; for (int i = tid; i < len; i += totalThreads) { - + const auto xOffset = shape::getIndexOffset(i, xShapeInfo, len); const auto qOffset = shape::getIndexOffset(i, qShapeInfo, len); const auto zOffset = shape::getIndexOffset(i, zShapeInfo, len); @@ -65,10 +65,10 @@ void zeta(nd4j::LaunchContext * context, const NDArray& x, const NDArray& q, NDA if(!x.isActualOnDeviceSide()) x.syncToDevice(); if(!q.isActualOnDeviceSide()) q.syncToDevice(); - - int threadsPerBlock = MAX_NUM_THREADS; + + int threadsPerBlock = MAX_NUM_THREADS / 2; int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - + BUILD_SINGLE_SELECTOR(x.dataType(), zetaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), q.getSpecialBuffer(), q.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES); x.tickReadHost(); diff --git a/libnd4j/tests_cpu/layers_tests/AveragingArrayTests.cpp b/libnd4j/tests_cpu/layers_tests/AveragingArrayTests.cpp deleted file mode 100644 index 06eb6c4a1..000000000 --- a/libnd4j/tests_cpu/layers_tests/AveragingArrayTests.cpp +++ /dev/null @@ -1,63 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include "testlayers.h" -#include -#include - -using namespace nd4j; -using namespace nd4j::ops; -using namespace nd4j::graph; - -class AveragingArrayTests : public testing::Test { -public: - -}; - -TEST_F(AveragingArrayTests, test_basic_reads_1) { - auto exp0 = NDArrayFactory::create('c', {1, 5},{3.0, 3.0, 3.0, 3.0, 3.0}); - - auto original = NDArrayFactory::create('c', {100, 5}); - original.assign(1.0); - - AveragingArrayProxy proxy(&original); - - auto writeable0 = proxy.writeable(1, 0); - auto writeable1 = proxy.writeable(1, 1); - auto writeable2 = proxy.writeable(2, 1); - - ASSERT_FALSE(writeable0 == nullptr); - ASSERT_FALSE(writeable1 == nullptr); - ASSERT_FALSE(writeable2 == nullptr); - - writeable0->assign(2.0); - writeable1->assign(4.0); - writeable2->assign(3.0); - - auto r = proxy.collapseWrites(); - - ASSERT_TRUE(r); - - auto row1 = original({1,2, 0,0}, true); - auto row2 = original({2,3, 0,0}, true); - - ASSERT_EQ(exp0, row1); - ASSERT_EQ(exp0, row2); -} diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index 987206a14..d8cf86495 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -159,9 +159,9 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_3) { auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); - auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{ 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, - 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f, - 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, + auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{ 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f, + 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f}); input = 2.; weights.linspace(0.1, 0.1); @@ -2211,55 +2211,6 @@ TEST_F(ConvolutionTests1, upsampling3d_test2) { delete results; } -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, upsampling2d_bp_test1) { - - const int bS=1, iH=2,iW=2, iC=1; - const int factorH=2, factorW=2; - const int isNCHW = 1; // data format, default is NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, iH*factorH, iW*factorW}); - gradO = 1.; - - auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW}); - expGradI = 4.; - - nd4j::ops::upsampling2d_bp op; - auto results = op.execute({&input, &gradO}, {}, {isNCHW}); - auto* gradI = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - delete results; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, upsampling2d_bp_test2) { - - const int bS=1, iH=2,iW=2, iC=1; - const int factorH=2, factorW=2; - const int isNCHW = 0; // data format, default is NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto gradO = NDArrayFactory::create('c', {bS, iH*factorH, iW*factorW, iC}); - gradO = 1.; - - auto expGradI = NDArrayFactory::create('c', {bS, iH, iW, iC}); - expGradI = 4.; - - nd4j::ops::upsampling2d_bp op; - auto results = op.execute({&input, &gradO}, {}, {isNCHW}); - auto* gradI = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - delete results; -} ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, upsampling3d_bp_test1) { @@ -2315,6 +2266,70 @@ TYPED_TEST(TypedConvolutionTests1, conv2D_input_BP_test1) { delete results; } +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, upsampling3d_bp_test3) { + + const int bS=1, iD=3,iH=3,iW=3, iC=2; + const int factorD=2, factorH=2, factorW=2; + const int isNCDHW = 1; // data format, default is NCHW + + NDArray input('c', {bS, iC, iD, iH, iW}, nd4j::DataType::FLOAT32); + NDArray gradO('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, {0.6793504, 0.35508695, 0.84278935, 0.20031333, 0.7014987, 0.31069338, + 0.44793984, 0.93800974, 0.32667395, 0.15187258, 0.38331753, 0.78212297, 0.1988072, 0.7985636, 0.1632634, 0.14696825, 0.26089668, + 0.13505761, 0.7562093, 0.27545404, 0.36908787, 0.09282647, 0.83649176, 0.26841334, 0.09506222, 0.31279507, 0.13591796, 0.5175439, + 0.32870287, 0.061735712, 0.39643127, 0.248016, 0.5489592, 0.115046196, 0.8143622, 0.7215636, 0.40449402, 0.29908907, 0.4038839, + 0.9883108, 0.022296403, 0.927782, 0.3184157, 0.0685462, 0.28453344, 0.23272, 0.35214192, 0.058909304, 0.7112212, 0.6744568, 0.19694561, + 0.6994972, 0.0743224, 0.42042503, 0.5842631, 0.14957358, 0.44640633, 0.72307247, 0.06448108, 0.48307765, 0.8759956, 0.5698191, 0.4458631, + 0.5277549, 0.016646361, 0.753678, 0.14063567, 0.7541292, 0.16193217, 0.7750374, 0.3326449, 0.11739397, 0.017710684, 0.60847557, 0.52515227, + 0.9171938, 0.84989065, 0.5894228, 0.85227835, 0.39063585, 0.88968325, 0.6694452, 0.698873, 0.96147966, 0.15740126, 0.15736352, 0.49352047, + 0.5699365, 0.12683152, 0.11572781, 0.7863682, 0.737939, 0.49007934, 0.6084143, 0.9564999, 0.3900982, 0.14730452, 0.8506447, 0.49765033, + 0.07186628, 0.08214969, 0.035314173, 0.7320408, 0.36993408, 0.8406658, 0.27389422, 0.43179566, 0.13323106, 0.19297548, 0.24689731, 0.38641843, + 0.51154125, 0.19903564, 0.1416313, 0.69769853, 0.25363067, 0.78221816, 0.9300991, 0.3355119, 0.5588076, 0.6643576, 0.018850708, 0.63755876, + 0.2904297, 0.43490165, 0.84251267, 0.46609768, 0.38139546, 0.52318525, 0.9901826, 0.9257676, 0.6434591, 0.016828254, 0.9187561, 0.22897908, + 0.0063138064, 0.66597503, 0.19036093, 0.59552056, 0.69888055, 0.22146936, 0.9124342, 0.8708221, 0.7273687, 0.52397245, 0.66288394, 0.2188415, + 0.3354802, 0.03566524, 0.5101009, 0.5017283, 0.75122046, 0.1884508, 0.7407126, 0.6253045, 0.47145858, 0.5369367, 0.19884548, 0.99008304, + 0.08256686, 0.91884845, 0.02360027, 0.98895234, 0.3751719, 0.91783875, 0.4338776, 0.6783008, 0.6667967, 0.46720362, 0.7508773, 0.52304846, + 0.76631916, 0.4187526, 0.7653719, 0.5159193, 0.42730415, 0.49462363, 0.2731735, 0.8862948, 0.043214794, 0.3197591, 0.040378205, 0.5427239, + 0.9228089, 0.045940384, 0.70047987, 0.8419288, 0.53966296, 0.009444186, 0.038044546, 0.03158029, 0.43485752, 0.9204235, 0.5478789, 0.8290083, + 0.11868837, 0.0229866, 0.6639305, 0.8757367, 0.8279557, 0.76270294, 0.43242732, 0.4713431, 0.2569212, 0.30575937, 0.44395888, 0.99384075, + 0.6127142, 0.44844577, 0.6347944, 0.098358564, 0.34233716, 0.9329664, 0.65776783, 0.108565055, 0.2052629, 0.46441218, 0.041791342, 0.89369565, + 0.7000381, 0.2106213, 0.51152664, 0.44200692, 0.8293282, 0.20901772, 0.6387249, 0.8016979, 0.11178707, 0.109545894, 0.19654618, 0.060582615, + 0.08239174, 0.64630795, 0.32862368, 0.60225064, 0.8328141, 0.5484566, 0.8120276, 0.38822946, 0.6742381, 0.34913155, 0.42887798, 0.45344824, + 0.73956585, 0.9714739, 0.42937812, 0.45185348, 0.84535813, 0.046436775, 0.8802151, 0.8676222, 0.42625394, 0.4985318, 0.42399272, 0.122144565, + 0.0060101906, 0.47253844, 0.18123977, 0.86316174, 0.5863874, 0.3852012, 0.9785553, 0.0054711984, 0.88500834, 0.020897374, 0.27467912, 0.3852802, + 0.0766939, 0.94622654, 0.38687763, 0.3308602, 0.7770494, 0.9052543, 0.22258204, 0.42207044, 0.18050623, 0.21057767, 0.012561422, 0.7977821, + 0.61251044, 0.7203693, 0.6028265, 0.6036933, 0.1446382, 0.6712341, 0.76634467, 0.4854034, 0.26634562, 0.76523924, 0.16348523, 0.2663676, + 0.96846986, 0.8273284, 0.10700377, 0.7600526, 0.6771002, 0.47963092, 0.21264452, 0.56934077, 0.5514792, 0.85725874, 0.99090636, 0.54562527, + 0.93597686, 0.21142527, 0.4628326, 0.35011524, 0.31464386, 0.31164807, 0.65928996, 0.94418925, 0.39666295, 0.9496393, 0.103756346, 0.482158, + 0.49171793, 0.4108867, 0.22594318, 0.97093135, 0.5974685, 0.34632966, 0.54835194, 0.10499302, 0.9767778, 0.55008715, 0.54379046, 0.3583731, + 0.33369112, 0.04279039, 0.24939054, 0.23943715, 0.06775989, 0.7750291, 0.24329625, 0.4327169, 0.86916673, 0.80322117, 0.049972698, 0.47177452, + 0.37419558, 0.15303156, 0.121425234, 0.75884604, 0.8191354, 0.48554084, 0.053899214, 0.7858246, 0.39219773, 0.77579063, 0.34507045, 0.46070176, + 0.14496958, 0.47706795, 0.50678796, 0.64902323, 0.3277943, 0.0017530271, 0.6536156, 0.8582253, 0.95703506, 0.9963951, 0.8239163, 0.305142, + 0.012419582, 0.9498972, 0.1595827, 0.47947606, 0.5071124, 0.78227425, 0.2066719, 0.5217094, 0.7841406, 0.5260441, 0.49798164, 0.10975622, + 0.8633349, 0.76298475, 0.14295428, 0.6131504, 0.43794408, 0.50339264, 0.4504877, 0.19235311, 0.6678411, 0.80769485, 0.67495126, 0.96461457, + 0.10535406, 0.66438645, 0.4372345, 0.93851465, 0.8635335, 0.3405871, 0.45652762, 0.3636232, 0.52931345, 0.20154329, 0.07698499, 0.6125804, + 0.3583082, 0.3894796, 0.32601944, 0.5237369, 0.66683626, 0.08541841, 0.4815708, 0.11897489, 0.97555137, 0.3602705, 0.9620871, 0.6361821, + 0.71167386, 0.5134439, 0.57761437, 0.58598644, 0.39387667, 0.6966405, 0.46841687, 0.85788506, 0.9957087, 0.051309288, 0.24846801, 0.55938333, + 0.10230542, 0.9370694, 0.57527155, 0.54656035, 0.28896323, 0.51303476, 0.8865, 0.38641605, 0.9836358}, nd4j::DataType::FLOAT32); + + NDArray expGradI('c', {bS, iC, iD, iH, iW}, {3.510932, 3.4310975, 3.538762, 4.148549, 2.8380678, 2.5431657, 3.3928843, 3.228055, 3.1467278, + 3.2603023, 5.611751, 4.334653, 3.3697734, 4.603307, 4.4357986, 4.32991, 3.0532732, 3.1370173, 4.181534, 2.9965065, 2.8553872, 5.2719016, + 4.5671935, 3.7027276, 3.3517184, 5.2544537, 3.5107024, 4.1496124, 3.9333878, 3.1798909, 3.1446428, 3.0932689, 3.9730802, 3.0466917, + 4.9675374, 4.769673, 3.766952, 3.6375027, 3.6492167, 4.9440994, 3.8379507, 3.467589, 4.719474, 3.1295977, 4.5177174, 4.2760015, 2.8443856, + 4.225355, 4.377341, 4.4398847, 4.710785, 4.4199953, 3.928307, 4.8769503}, nd4j::DataType::FLOAT32); + + nd4j::ops::upsampling3d_bp op; + auto results = op.execute({&input, &gradO}, {}, {isNCDHW}); + auto* gradI = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + delete results; +} + + ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, deconv2d_test1) { diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index c1b34d779..a27c67fc4 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -152,18 +152,18 @@ TYPED_TEST(TypedConvolutionTests2, Test_DeConv2D_TF_1) { TYPED_TEST(TypedConvolutionTests2, Test_DeConv2D_TF_2) { auto input0 = NDArrayFactory::create('c', {4}, {3, 8, 8, 16}); - auto input1 = NDArrayFactory::create('c', {7, 7, 16, 5}, {1.05293429f, -0.89349967f, 0.31027254f, 1.22991478f, -0.62926656f, 0.56918693f, --1.60992694f, 1.10167944f, -0.80843484f, 0.07521993f, -1.15994942f, 0.76016301f, -0.40056285f, -1.16872537f, -0.91384381f, -0.36700436f, 1.82389200f, -1.18200207f, 0.51612782f, -0.92479187f, -0.09307563f, -0.55122334f, 1.23532486f, -1.11124146f, -0.05812126f, 0.68159896f, 0.69125599f, -0.77127314f, -0.10874277f, 0.86469102f, --1.31614351f, 0.33354419f, -1.71750402f, 0.17197680f, -1.03965557f, 1.10570908f, -1.19115615f, 1.05115080f, 0.18277600f, 1.08820546f, -0.72191417f, -0.10999311f, 1.56521320f, -0.35433730f, -1.11799145f, 0.34499285f, 0.64998639f, -1.64371550f, 0.92592359f, -0.47659501f, 0.49101439f, -0.15613313f, 1.47486567f, 0.43576995f, -2.19538260f, -0.83567709f, -1.21846950f, 0.80400819f, 1.14637423f, -1.01503456f, -0.61992753f, -0.47378838f, 0.86503726f, 0.27147385f, 0.37073180f, -0.19951358f, 0.79167330f, -0.33982825f, 0.18631981f, -1.54715073f, 0.39967480f, 0.95067030f, 1.12508667f, -0.86676019f, -1.10341156f, 2.33141375f, 1.10972047f, 0.71407092f, -1.70640314f, 1.80666339f, 0.59465605f, -0.39653218f, -2.61163163f, -1.15013492f, -1.19908321f, 0.41783467f, -0.22730024f, 0.31425011f, -0.58562893f, -0.10131568f, -0.85047537f, -2.59974790f, 1.22072542f, -2.08812046f, -0.19363593f, -1.27664304f, -0.02703438f, 1.08477545f, -0.65506506f, 0.46040919f, -0.13715318f, --0.74945593f, -0.69006950f, -1.29617655f, -0.15865716f, 1.38956285f, 0.90216327f, -1.31185400f, -0.15067385f, -0.63093358f, -0.05895613f, 0.26545224f, 0.29332840f, 0.42852548f, 0.72409540f, 0.12879130f, 1.43038857f, 0.68647617f, 2.19654775f, 0.51878077f, -0.03769343f, 0.52877223f, -0.21733910f, 1.13710785f, -0.59003806f, -1.54624867f, -0.64997369f, -1.03239334f, 0.19708300f, 0.68658423f, 0.71048903f, -1.55250466f, -1.38636279f, 0.32385820f, 0.81226677f, 0.19209047f, -0.23002781f, -0.63631231f, 1.02101684f, 0.65428704f, -0.17206922f, 1.09488952f, 1.03022420f, -0.95567745f, -0.07595373f, -1.48606372f, 2.57174873f, -1.75366247f, 1.12913883f, -0.97053039f, -0.28552356f, 0.56511772f, -0.79568213f, 0.07561764f, -1.02085686f, 1.05770981f, -1.25715709f, 0.42046708f, -2.57390857f, 0.96947151f, 1.05215812f, 0.65624017f, -1.29019403f, 0.64157075f, -0.40509227f, -0.65354455f, 0.42348680f, -1.34107757f, 0.05931387f, -0.54337227f, 0.95460182f, 1.59319806f, -0.44433126f, --0.33717924f, 0.79566282f, 0.50112695f, -0.22244534f, 1.76904583f, -0.89817202f, 1.82985342f, 0.17671813f, 0.80720717f, 1.32469308f, 0.39417782f, -0.23720963f, 0.96796370f, -1.02348757f, -0.86615551f, -1.58120525f, -0.37634999f, 0.00905940f, 0.01880967f, 1.75771821f, -0.64372772f, 0.36687651f, 0.15854552f, -0.67599791f, -0.53726906f, -1.20158446f, -1.78549063f, 0.96476388f, -0.66158366f, -0.41681561f, -0.97541636f, 2.35928202f, 0.32130197f, 1.06886065f, 1.38736427f, -0.73718959f, 0.11215294f, 2.12865782f, -0.37927702f, 0.55621815f, -1.10108411f, -0.02032263f, 0.29595461f, 1.58737493f, 1.24001300f, -0.66748160f, 0.80729002f, -0.10575818f, --1.03175950f, 1.80755460f, 0.10825710f, 2.20666361f, 1.33633149f, 1.39290452f, 0.45211342f, -0.07837920f, 2.08304930f, -0.28387162f, -0.70775616f, 0.43626297f, 0.53556961f, 0.06201901f, -0.59255266f, -0.11854446f, 2.10024118f, 0.37638292f, -0.56178707f, -0.25220188f, -1.23731256f, -1.30002999f, 0.34283713f, 0.30502397f, --1.09233856f, 1.12430644f, 0.52273953f, -0.68507338f, -0.69913578f, 0.88440478f, -0.76959240f, 1.07093310f, -0.34802195f, 0.35683727f, -0.76079178f, -1.92807376f, 0.84499562f, 1.39131641f, 0.44825050f, 0.34567752f, 0.44607711f, -1.00986362f, -0.50038189f, -0.09060892f, -2.55645394f, 0.56416476f, -0.83058155f, -0.65931624f, + auto input1 = NDArrayFactory::create('c', {7, 7, 16, 5}, {1.05293429f, -0.89349967f, 0.31027254f, 1.22991478f, -0.62926656f, 0.56918693f, +-1.60992694f, 1.10167944f, -0.80843484f, 0.07521993f, -1.15994942f, 0.76016301f, -0.40056285f, -1.16872537f, -0.91384381f, -0.36700436f, 1.82389200f, -1.18200207f, 0.51612782f, -0.92479187f, -0.09307563f, -0.55122334f, 1.23532486f, -1.11124146f, -0.05812126f, 0.68159896f, 0.69125599f, -0.77127314f, -0.10874277f, 0.86469102f, +-1.31614351f, 0.33354419f, -1.71750402f, 0.17197680f, -1.03965557f, 1.10570908f, -1.19115615f, 1.05115080f, 0.18277600f, 1.08820546f, -0.72191417f, -0.10999311f, 1.56521320f, -0.35433730f, -1.11799145f, 0.34499285f, 0.64998639f, -1.64371550f, 0.92592359f, -0.47659501f, 0.49101439f, -0.15613313f, 1.47486567f, 0.43576995f, +2.19538260f, -0.83567709f, -1.21846950f, 0.80400819f, 1.14637423f, -1.01503456f, -0.61992753f, -0.47378838f, 0.86503726f, 0.27147385f, 0.37073180f, -0.19951358f, 0.79167330f, -0.33982825f, 0.18631981f, -1.54715073f, 0.39967480f, 0.95067030f, 1.12508667f, -0.86676019f, -1.10341156f, 2.33141375f, 1.10972047f, 0.71407092f, +1.70640314f, 1.80666339f, 0.59465605f, -0.39653218f, -2.61163163f, -1.15013492f, -1.19908321f, 0.41783467f, -0.22730024f, 0.31425011f, -0.58562893f, -0.10131568f, -0.85047537f, -2.59974790f, 1.22072542f, -2.08812046f, -0.19363593f, -1.27664304f, -0.02703438f, 1.08477545f, -0.65506506f, 0.46040919f, -0.13715318f, +-0.74945593f, -0.69006950f, -1.29617655f, -0.15865716f, 1.38956285f, 0.90216327f, -1.31185400f, -0.15067385f, -0.63093358f, -0.05895613f, 0.26545224f, 0.29332840f, 0.42852548f, 0.72409540f, 0.12879130f, 1.43038857f, 0.68647617f, 2.19654775f, 0.51878077f, -0.03769343f, 0.52877223f, -0.21733910f, 1.13710785f, -0.59003806f, +1.54624867f, -0.64997369f, -1.03239334f, 0.19708300f, 0.68658423f, 0.71048903f, -1.55250466f, -1.38636279f, 0.32385820f, 0.81226677f, 0.19209047f, -0.23002781f, -0.63631231f, 1.02101684f, 0.65428704f, -0.17206922f, 1.09488952f, 1.03022420f, -0.95567745f, -0.07595373f, -1.48606372f, 2.57174873f, -1.75366247f, 1.12913883f, +0.97053039f, -0.28552356f, 0.56511772f, -0.79568213f, 0.07561764f, -1.02085686f, 1.05770981f, -1.25715709f, 0.42046708f, -2.57390857f, 0.96947151f, 1.05215812f, 0.65624017f, -1.29019403f, 0.64157075f, -0.40509227f, -0.65354455f, 0.42348680f, -1.34107757f, 0.05931387f, -0.54337227f, 0.95460182f, 1.59319806f, -0.44433126f, +-0.33717924f, 0.79566282f, 0.50112695f, -0.22244534f, 1.76904583f, -0.89817202f, 1.82985342f, 0.17671813f, 0.80720717f, 1.32469308f, 0.39417782f, -0.23720963f, 0.96796370f, -1.02348757f, -0.86615551f, -1.58120525f, -0.37634999f, 0.00905940f, 0.01880967f, 1.75771821f, -0.64372772f, 0.36687651f, 0.15854552f, -0.67599791f, +0.53726906f, -1.20158446f, -1.78549063f, 0.96476388f, -0.66158366f, -0.41681561f, -0.97541636f, 2.35928202f, 0.32130197f, 1.06886065f, 1.38736427f, -0.73718959f, 0.11215294f, 2.12865782f, -0.37927702f, 0.55621815f, -1.10108411f, -0.02032263f, 0.29595461f, 1.58737493f, 1.24001300f, -0.66748160f, 0.80729002f, -0.10575818f, +-1.03175950f, 1.80755460f, 0.10825710f, 2.20666361f, 1.33633149f, 1.39290452f, 0.45211342f, -0.07837920f, 2.08304930f, -0.28387162f, -0.70775616f, 0.43626297f, 0.53556961f, 0.06201901f, -0.59255266f, -0.11854446f, 2.10024118f, 0.37638292f, -0.56178707f, -0.25220188f, -1.23731256f, -1.30002999f, 0.34283713f, 0.30502397f, +-1.09233856f, 1.12430644f, 0.52273953f, -0.68507338f, -0.69913578f, 0.88440478f, -0.76959240f, 1.07093310f, -0.34802195f, 0.35683727f, -0.76079178f, -1.92807376f, 0.84499562f, 1.39131641f, 0.44825050f, 0.34567752f, 0.44607711f, -1.00986362f, -0.50038189f, -0.09060892f, -2.55645394f, 0.56416476f, -0.83058155f, -0.65931624f, -0.73649710f, 0.59814465f, -0.86736494f, -0.32200798f, -1.28087902f, -0.76818323f, 0.86848933f, -0.98678392f, -1.30813944f, -0.20255326f, 0.26557815f, -0.31090519f, -1.46331608f, -0.62782109f, 0.59034890f, 1.63147473f, -0.17727259f, -0.37636510f, 1.27368402f, 0.19096918f, -0.29936951f, -1.99038267f, 0.54831523f, 0.48849005f, -2.55680346f, -0.63126534f, 1.21715927f, 1.22841084f, -0.67416084f, 0.02927168f, -0.36693662f, 0.63204330f, 0.13721083f, 0.28742912f, 0.19470036f, 0.74873924f, -1.47602463f, 0.86264688f, -0.23730527f, -0.99978864f, -1.17048764f, -0.34996086f, 1.43019187f, 0.26224539f, 0.60689932f, -0.75002515f, -0.79823422f, -1.37300086f, -0.19951135f, -0.12150808f, -0.75272322f, 0.23755015f, 0.31270382f, 1.66539109f, -1.04104745f, 0.79540199f, -0.54042423f, -0.54150617f, 0.43871084f, 0.24163951f, -0.24517761f, -0.66178995f, -1.13064528f, -0.84426326f, 0.56437236f, 0.09088907f, -0.82823074f, 0.81753862f, -1.74096012f, -1.80599844f, -0.60943592f, 1.36094582f, -1.47762752f, 0.15931177f, 1.05569172f, 0.36751524f, 0.06497604f, 0.13536447f, -1.57156146f, 0.22783801f, -0.96910107f, -1.24294984f, -1.47147155f, -1.04790676f, 0.64629447f, -0.32266054f, -0.55675793f, -0.95612079f, -0.23005411f, -0.75229394f, 0.03050950f, -1.72484553f, -2.06055546f, 0.19892083f, -0.13597751f, 0.65180075f, 0.27096850f, 0.08977254f, 0.57564765f, -0.43227410f, 0.09541437f, -0.00358280f, 0.65680492f, 0.04006556f, 0.57160908f, 0.43821687f, 1.96118212f, 0.42602235f, -0.36731303f, 0.67200917f, -0.56667900f, 0.44014785f, 0.06970236f, -1.34415269f, -1.13301528f, -0.08848868f, 0.35615012f, -0.06426942f, -0.81406075f, 0.94097465f, -0.54560357f, -0.65877116f, -1.29646838f, -1.13109028f, -1.64186084f, -2.12723470f, 1.86027610f, 1.22621441f, 0.26098135f, -0.05608099f, 0.21143445f, -0.87244326f, 0.79408187f, 1.24279130f, 0.14458629f, 0.25532281f, -1.24023473f, 2.42278886f, 0.00405578f, -1.00119174f, 1.19856644f, -1.37395728f, -0.16656208f, 0.46858498f, -0.00678801f, -0.34960639f, 0.16614936f, 2.41560221f, -0.53880709f, 0.91618651f, -1.77009308f, 0.32911557f, 0.30216452f, 0.02881077f, 0.77705866f, 0.27061903f, -0.07440855f, -1.14010465f, 1.25383139f, -1.58615100f, 1.04185510f, 0.15140508f, -0.88059032f, -0.33872122f, -0.42526904f, 2.17365575f, 0.29308075f, -2.24234557f, -1.03164542f, -0.09263755f, 0.08050421f, -0.74946511f, -0.64589006f, -1.13416314f, -0.64989561f, 0.16502371f, -0.33831969f, 0.22832428f, -0.08389475f, -0.28009200f, 1.34536922f, -0.19075738f, 0.36238208f, 0.83690089f, 0.26144615f, 0.04457319f, -2.55585861f, -0.01807522f, 1.68334866f, -0.05795629f, -0.21315987f, -1.84039557f, 0.06512877f, -1.77318645f, -0.27637982f, 0.20439345f, 0.67558700f, -0.77179354f, -0.17902173f, 0.70381826f, -0.40395790f, -0.96492916f, 0.84138173f, 2.43879008f, -0.32297835f, -1.74370265f, -0.10330839f, -1.07465363f, 1.85030377f, -0.59153467f, 0.99667048f, -0.56753993f, 0.57383025f, -1.90630126f, 1.24299097f, 0.22797665f, 0.30468231f, -0.07360230f, 1.64654350f, 0.57195550f, 0.03227921f, 1.11005175f, 0.00088721f, 1.19266295f, 0.61323351f, 0.13754399f, 0.59900171f, -0.75831634f, 1.11500823f, 0.99747783f, -1.36923385f, 1.26563418f, 0.01253266f, 0.35483193f, 1.95143735f, -2.02703261f, -1.38265920f, -0.02404256f, 2.02788448f, -0.75144875f, -0.58445263f, 0.26129767f, 0.60691077f, -1.84661067f, 0.65872228f, -0.58298993f, 0.33067298f, -0.09431327f, 0.43333948f, -1.52616286f, -0.25961858f, -1.65459549f, -0.72950101f, -0.89906919f, -0.80081612f, -1.32189929f, -1.36574399f, -0.35809481f, 0.36385000f, 0.31480747f, -0.35797358f, -1.04066050f, 0.07971872f, -0.21176252f, -0.76559299f, -0.10352154f, 0.29248312f, -1.75030553f, 0.68219930f, 0.56189102f, -1.11212170f, 0.06501702f, -0.07131009f, 1.23410738f, 0.29311740f, -1.02052307f, 1.40220940f, -1.00995779f, 0.57955760f, 0.22640309f, 0.74853230f, -0.02586563f, -0.33427954f, 1.70311153f, -0.53405988f, 0.90975094f, -0.46450076f, 0.19904344f, 0.28559047f, 0.23167793f, -0.69065529f, -0.17176504f, -0.29301846f, -0.85477978f, -0.00267053f, -0.28529504f, -0.64201307f, 1.03479636f, 1.03805065f, 0.83270210f, -0.09405448f, 2.50615931f, 0.62019676f, 0.31354564f, -1.51599669f, 0.42848015f, 0.66263914f, 0.74651009f, -1.13042867f, -0.58933645f, -0.35146511f, 0.06223279f, 0.28065836f, 0.66506970f, 0.16942430f, -0.23316263f, -0.87481076f, 1.21992230f, 1.48536301f, -0.79667616f, -0.75519305f, 1.40999961f, -0.42802793f, -0.20252463f, 0.30573779f, -0.23319976f, 1.77525878f, -1.80704832f, 2.71519923f, -0.67500192f, 0.12268137f, -0.13014549f, -0.07479453f, -1.51065743f, 1.04198146f, 0.96205556f, -2.00525570f, -0.37911776f, 0.89329720f, -0.39495832f, -0.03683375f, -0.90928614f, -1.56263304f, 0.45038295f, -2.62184358f, -0.45686841f, -0.52536523f, 1.05351484f, 0.89982438f, -0.63724512f, 3.21004057f, -0.08608918f, 1.55209303f, 0.62688643f, -0.59702635f, 1.85774517f, 0.38172096f, -1.25640929f, -2.59278178f, 0.85050315f, -1.10080361f, -1.26422560f, -1.80045366f, -0.34494889f, 0.68448657f, 1.25671864f, -1.26594126f, 0.32244179f, -0.51956522f, -0.56212711f, -0.95574015f, 0.71973872f, 0.46736258f, -0.11772985f, -1.52736545f, 0.19571695f, 0.73147154f, 0.87724912f, -0.26265728f, -2.60267401f, 0.19263546f, 0.18320183f, 0.11485019f, -0.82999659f, 0.13582672f, -0.08040185f, 0.28152901f, -0.51421624f, -2.32467175f, 0.19923948f, 0.64616692f, 0.29718629f, 0.32785949f, -0.62266952f, -0.98174316f, 1.23276305f, 0.58563638f, 1.28528512f, -2.13718534f, 0.28842899f, 0.12676710f, -1.72105229f, 0.15053287f, 2.19496536f, 1.28683448f, -0.96318281f, 0.17043279f, -0.05245409f, -0.38710704f, -0.30441490f, -0.08249986f, 0.28423953f, 0.72963721f, -1.49658203f, 0.99077344f, -0.78913772f, -1.12661564f, -1.26294816f, 0.16517465f, 0.10124251f, -0.77198768f, -0.16342169f, 0.08615876f, 0.49711797f, -0.66083062f, 0.76648003f, 1.04756033f, 1.46122825f, -0.42798752f, -2.29203916f, 0.30444992f, 0.58697921f, 1.22166932f, 0.09022947f, -0.03920181f, 0.10444995f, 0.10361757f, 1.18224072f, -0.76641631f, 0.90802073f, 1.41639423f, 1.55682337f, 1.28101575f, -0.35396016f, 1.11443567f, 1.18218529f, -0.06048089f, 0.85024464f, -1.01789165f, -0.69154263f, 0.06663221f, 0.68429029f, 0.12560424f, 0.37915874f, -0.66829866f, -0.64524972f, -0.05568011f, 0.12230454f, -0.35041061f, 0.62027830f, -0.16739209f, -0.72145337f, 0.46263054f, -1.67837834f, 0.69413221f, -0.57243419f, 0.37638462f, -0.21446526f, -0.89821470f, 0.60078722f, -1.06706369f, -1.26132309f, 0.35714921f, 2.39221811f, -0.09376130f, 0.30760849f, 0.59180892f, 0.55815399f, -0.32628775f, 1.28890121f, -2.53237987f, -0.98241091f, 1.10520673f, -1.74751687f, -0.90837651f, -0.25220659f, -0.56625104f, -0.30691949f, 0.16058689f, 0.44309673f, -1.09874964f, -0.76747823f, -0.33679363f, -0.02535496f, 0.00990100f, 1.35318136f, -0.70140815f, 0.50937581f, 0.55386209f, -1.21721983f, 0.71376961f, -0.18079315f, -0.11077732f, 0.09292522f, -0.57235324f, 0.62748206f, 0.42587611f, 0.64860481f, -1.10635614f, 1.66414368f, 0.47505483f, 1.48602211f, -0.59611166f, -0.41932896f, -0.96542233f, -0.41756630f, -1.02963889f, -0.70070386f, 1.65803933f, 0.20138647f, 0.05895034f, -1.46152759f, -0.37278318f, 1.05535650f, 0.34437978f, -1.13257408f, 0.17635690f, 0.09386671f, 0.37079874f, 1.47695887f, -1.58420062f, -0.26100200f, 0.44847637f, 0.88847303f, -0.13877590f, -0.64620668f, -0.38019657f, 1.01608157f, 0.13357787f, 0.05137976f, 0.93498152f, -0.62226880f, 0.80461699f, -0.71682596f, -0.88756353f, 0.40933055f, -1.52167451f, 0.79756850f, -0.17307425f, 0.62368619f, -0.22466940f, -1.72802913f, 0.59047443f, -0.58020931f, 0.09096476f, -0.07317388f, 0.44522321f, -0.64880705f, 0.15684015f, 0.08708375f, -0.41556796f, 1.11579072f, -0.81733495f, 0.11643656f, -0.73995101f, 0.93685871f, 1.57971406f, 0.67606360f, 0.70509088f, -0.25283816f, -0.00010609f, -0.61884147f, -0.86409342f, 0.95383751f, -0.05895388f, -1.45261180f, 0.45166013f, -1.01434863f, 0.18496066f, 1.06517637f, 1.81127059f, 0.89470667f, -0.13232610f, 0.46958798f, 0.13884509f, 0.57117194f, 0.29575035f, -0.97884250f, 0.83291447f, -0.59255791f, -0.04354135f, -0.19431923f, 0.30071029f, -0.95421529f, 0.76359886f, -0.47799742f, 0.68254346f, 1.19368529f, -0.48935115f, 0.30357337f, -0.50225669f, -0.23370270f, 1.96702433f, 1.46558523f, 2.68482018f, 0.41622332f, 0.73697484f, 1.43430734f, 0.15387188f, 0.20875402f, -2.49335337f, -1.39674246f, -0.22125854f, -0.00424605f, 0.91416460f, 0.33384630f, 0.44703746f, 0.25610185f, 0.38966551f, -0.01784045f, 1.66148460f, 0.36005461f, 0.95716912f, -0.18246566f, -0.15480693f, 0.38775176f, -0.56969136f, -0.29644895f, -1.04565966f, -1.00455630f, 0.30897698f, -1.46885884f, 0.03657720f, -0.49302089f, 1.34134722f, 0.01673754f, 1.22725964f, 0.55256772f, 0.63803208f, -0.29041430f, 1.11455286f, 0.76329172f, 0.27073982f, 0.77173829f, -1.79884446f, -0.11889492f, -1.92040312f, -0.46382675f, 0.20078070f, -0.98889589f, 1.46711135f, -1.68280172f, -0.52852470f, 0.66245162f, 0.29575166f, 1.34826505f, -0.22362417f, -0.14345661f, -2.34815073f, 1.26572001f, 0.66505629f, 1.01141500f, 1.08030057f, 0.17036134f, 0.00168786f, -0.37282917f, 0.69206375f, 1.07367527f, -0.49708191f, 1.49504781f, 0.58224988f, 0.96593714f, -1.07661915f, 0.25202179f, 0.25531644f, 0.42357162f, -0.31236249f, 0.48383278f, -0.06361829f, 0.24131298f, -0.95695931f, -0.12589653f, 0.36134180f, 3.20266032f, -0.40879184f, -0.66985190f, 1.51674330f, 0.34072638f, 1.15076303f, -0.40199137f, 0.46223637f, -0.48608047f, 0.99119538f, -0.22506073f, 0.30968750f, 0.64210880f, 0.54640514f, 0.18607031f, 1.26293361f, -0.77960914f, 0.79572529f, 1.01936150f, 2.27160740f, -1.48034489f, 0.74466604f, 0.14863680f, 0.31102443f, -1.15673816f, -0.38609681f, -2.65026069f, -0.45524642f, -0.74022961f, 2.74991131f, 0.00103815f, -3.03303242f, -0.41556966f, -0.87103498f, 0.78306234f, -0.88195556f, -0.77297026f, 1.21203196f, -1.09754920f, -0.03556008f, -0.31546223f, 0.72954375f, 0.25251788f, 0.11378583f, 0.50921023f, 0.30301905f, -1.60631680f, 0.27152416f, 1.17342317f, -0.70891970f, -0.08392961f, 0.92137378f, -0.10568139f, -0.31653777f, -0.28878728f, 1.22166574f, 1.12693942f, -0.21325994f, 0.94010323f, 1.21796405f, -0.68866694f, 2.30724216f, 0.28141466f, 0.83481526f, -0.04885862f, 0.01675143f, 1.04355800f, -0.81050140f, 1.51300573f, 0.53429186f, -0.56439877f, 0.38572624f, -0.05620475f, 0.67644542f, 0.72528905f, 0.05937041f, -1.06315899f, -0.51393986f, 0.46937627f, -0.34699562f, -0.64765716f, -1.45512629f, 0.47739139f, -0.88228017f, -2.00791359f, 1.29929042f, 0.05482405f, -0.66725296f, -0.54735124f, 0.09972951f, 0.76675093f, 0.98748523f, 0.08900899f, -0.78854066f, 1.47970486f, -0.61667502f, 0.45625573f, -0.21766303f, -0.46250847f, -0.07130960f, 0.64414692f, 0.12784545f, 0.26393634f, 1.07720757f, -1.23938286f, 0.62483376f, -0.55001754f, -0.05358591f, 0.07322436f, 1.12003291f, -1.00830650f, -0.20486419f, 0.76664752f, 0.28850746f, -0.04464776f, -0.40146068f, 0.73262817f, -1.12827921f, -0.19989438f, -1.15999687f, 1.37973154f, 0.78881019f, -0.34762639f, 1.22088552f, -1.64088547f, 0.63218033f, 0.45736769f, 0.05502866f, 2.22683382f, -1.78935897f, -1.49635041f, 0.83450896f, 1.67770112f, 1.33909333f, 1.51158953f, 0.28595078f, -0.08593627f, 0.45812801f, -0.15193029f, 1.14770603f, -0.88920450f, -1.96352005f, -1.49894583f, 0.49629962f, 1.59872091f, 0.00903497f, 2.15563583f, 2.25149560f, -2.01200557f, 2.56229877f, -1.38850498f, 0.73552012f, -0.39378855f, 0.52616280f, -0.03685786f, 0.87403935f, 0.12163408f, 0.74297994f, -0.30697080f, 0.38139752f, 0.49113834f, -0.95485127f, -0.99908817f, 0.71716321f, 0.04000283f, -2.09645271f, 1.38789880f, 1.37198520f, 0.82493287f, 0.17114936f, 0.53696346f, -0.19516060f, -0.50377476f, -0.91730285f, -0.70113552f, -0.02406530f, 0.84943396f, -0.17428185f, -1.09140801f, -0.68156958f, 1.70756388f, -1.00399911f, 0.03023832f, -0.39023280f, -1.89737976f, 1.14469039f, -0.58337289f, -0.60037899f, -1.17490256f, -1.56342828f, 0.48714057f, 0.62266618f, -0.15967095f, 1.32789338f, -1.25700688f, -0.55633998f, -0.83128709f, -0.49346271f, 1.59561753f, -0.24675299f, 0.38012561f, 0.91796309f, -0.38522810f, -0.65509188f, 0.94100451f, -0.57324487f, 2.19070768f, 1.24058700f, -0.75978851f, -0.40460554f, 0.79189235f, 0.70192885f, 1.93569362f, -0.03070199f, 0.77010989f, 0.58794290f, 0.51087004f, 0.22892070f, 0.35007235f, 1.56023848f, -0.67453802f, -0.18485607f, 0.64349502f, -0.31489357f, -1.95834625f, 0.06560058f, 2.30394220f, 1.18194163f, -0.88034087f, -1.05000436f, -1.05471325f, -0.98481798f, 0.49904808f, 0.16438948f, -1.10297823f, -1.39736509f, 0.01306054f, -1.85160267f, -0.87292641f, -0.15418227f, 0.43412164f, 1.16518164f, 0.06273691f, 0.24659210f, -0.08267246f, 1.28885782f, 0.73575675f, -0.01019809f, -0.08753663f, -0.61827368f, -0.40863234f, 2.12599611f, -0.53620332f, 0.53789747f, -0.66386080f, -1.70461988f, 0.86608189f, -1.11151052f, 0.14120635f, 1.18858743f, -0.31760478f, -0.73533046f, 0.20978074f, -0.84074509f, 0.16523147f, -1.03362834f, 0.59721231f, 0.21318658f, 0.23671274f, 1.75115061f, 0.25363782f, -1.32541454f, 1.13056135f, 0.24652456f, 0.60381413f, 0.21478581f, 0.75044096f, -0.63125616f, -1.69889998f, -0.02116571f, 1.46165359f, 1.03068244f, 0.63693464f, 0.67795700f, 1.20033514f, -1.39205134f, -0.61743122f, 0.56549704f, 0.65182322f, -0.74250507f, -1.61939359f, 1.14054918f, -0.45725963f, 1.74519682f, -0.66251940f, -0.94811529f, -1.60865819f, -0.59968346f, 0.86309159f, -1.91936195f, -1.02646923f, -1.50352538f, 0.58292735f, 0.05320299f, 1.53582895f, 0.01069612f, 0.15226212f, -0.71840125f, -1.36896348f, 2.14600968f, 0.96626586f, -0.52014917f, 0.41001406f, 0.59478027f, 0.15282436f, 0.27790198f, 0.76614654f, -0.38971323f, -0.01839927f, -1.57882118f, 0.61391610f, -0.62133092f, -0.03968323f, -0.88467252f, -1.24041140f, 2.07306671f, -0.41776338f, 0.14537935f, -0.91069067f, 1.67362070f, 4.72630215f, -0.07395106f, 0.46280116f, -0.40843824f, 0.70683080f, -0.27510864f, -0.63465804f, -0.83630908f, -0.44419941f, 0.60405648f, -0.65039170f, -1.02413189f, 1.05983019f, 1.73366308f, 0.73343736f, -0.00895882f, -1.00826013f, 0.17323074f, 0.73995626f, 0.24128854f, 0.94510227f, 0.25557515f, 0.02244723f, -0.95197725f, -0.16297856f, -0.38497585f, 1.17993331f, 1.20282137f, -1.31491220f, 0.44229278f, -0.24349044f, -0.01230415f, 1.37944865f, 0.48554277f, -0.54510897f, -0.10793537f, 0.41121426f, -0.12889031f, 0.26434359f, 1.27966082f, 0.64518744f, -0.15577169f, -0.99864733f, -0.61746484f, 2.01614976f, 1.56254935f, 1.86473298f, -0.54662132f, -0.22047071f, -0.06118120f, 0.84799510f, 0.17009684f, -1.30523121f, 0.64000309f, 0.36299205f, -0.59620583f, 1.36372304f, -0.05389515f, -0.93849313f, 0.98043185f, -0.39373067f, -0.84898937f, 1.32077873f, 1.05988657f, -1.35339200f, 0.23259017f, 0.63816410f, -0.80297333f, 0.60017115f, 1.25715804f, 1.18894124f, -0.62473553f, 1.05611980f, 0.02335166f, 1.07509828f, 0.25873449f, -1.68341100f, 0.54547334f, 0.79288185f, -0.93678916f, 0.19202201f, -1.48575914f, 1.08649087f, 0.50851744f, -0.45758674f, -0.39734635f, 0.35637981f, -1.63079453f, -0.75910008f, 0.92640859f, -0.55599529f, -0.40276715f, 0.31307653f, 0.39907026f, -1.18830419f, 0.71051043f, 0.14157933f, -0.39581308f, -1.64361024f, -0.06161860f, -0.25312796f, 1.10018682f, 0.56500763f, 0.80385065f, 0.35395023f, 0.81813669f, 0.27644628f, 0.65563256f, 1.73197234f, 0.68178749f, 0.76769936f, 0.44597456f, 0.67761195f, 0.67635447f, -0.32315412f, 0.19330767f, -0.25557944f, 1.91693723f, 0.38335562f, 0.07107610f, -0.57384586f, 0.79184365f, 1.87835479f, 0.60902315f, -0.94220877f, 0.79479855f, -0.25656971f, 0.08739131f, 0.53384244f, 1.22159266f, -0.39152125f, -1.46373534f, -0.02458516f, 1.62825716f, -1.26112676f, 0.19967082f, -0.71114451f, 0.27929229f, 0.65001321f, -0.11868202f, -0.55587751f, 0.78069001f, 0.57969242f, -0.60274386f, 0.31650013f, 0.90339553f, 0.09453616f, -0.37119162f, -1.00320566f, 0.33299938f, -0.48636708f, 0.26342997f, -0.91914523f, 0.28682709f, -1.24780893f, -1.59254742f, 0.97176319f, 0.14744301f, -0.53056234f, -1.73221612f, -0.67645556f, 0.98705006f, 0.79895812f, -2.04333115f, -0.60132772f, -0.91653955f, -0.28094748f, 0.47943443f, 0.38157779f, -0.67648011f, 1.09093642f, 1.66012859f, -0.29358891f, -1.26773024f, 0.36747769f, -1.10141146f, 0.82383633f, -0.89772314f, -0.47145563f, 0.63939518f, -0.64430422f, -0.48889321f, -0.37680882f, -1.06962025f, -1.28689516f, 1.28365147f, 0.61859220f, -0.84676331f, 1.38404000f, 1.21053445f, -0.14871351f, 1.06349385f, 1.45878971f, -0.47362664f, 1.40707004f, 1.25224137f, 0.87364739f, 0.92858213f, 0.00157326f, 1.45661485f, -0.27318576f, 0.15482858f, -1.07058907f, -0.06903186f, -0.74147576f, -1.64111829f, -0.67226541f, -1.13458407f, 1.28511488f, -0.41041154f, 2.09085560f, 0.45243183f, -0.67437285f, 0.84960121f, -1.49300814f, -0.42961186f, -2.35021853f, 0.57255560f, -0.73903763f, 1.37607956f, -2.44575167f, 1.25105727f, 1.38575912f, -1.16299784f, -0.13719854f, -1.11507034f, 0.35796806f, -0.64511567f, -0.87903833f, 0.32833642f, -0.87696886f, 0.02714214f, 0.30224666f, -0.69118696f, -1.23500824f, 0.76678628f, -3.20508122f, -0.24704689f, 0.49019828f, -1.20862615f, -0.03778638f, -0.07273687f, -0.11517122f, -1.75857520f, -1.64188445f, 1.21574795f, 0.57325113f, 1.14370298f, -1.07824504f, 1.70653832f, -0.03700557f, -0.47645858f, 0.11065386f, -1.03143036f, -2.18094873f, -0.94403434f, -0.09335683f, -0.44817665f, 1.39707148f, -1.21947956f, 0.56575936f, -0.69612634f, -1.12361753f, -0.17105591f, 1.15422392f, 0.02840637f, 0.09469353f, -0.52859986f, -2.08487725f, 1.28789508f, -0.03740775f, 0.61196613f, 1.23405397f, 1.56595814f, -0.65800631f, 2.02985072f, -0.69446486f, -0.88443804f, -0.23448054f, -0.43628734f, -0.45888957f, -0.21943338f, 1.78258693f, 1.75214970f, 0.71804136f, 0.49782532f, 0.37886053f, -1.59176385f, -1.74758542f, -0.02820176f, 0.75398153f, 1.00119829f, 0.80881971f, -0.53365272f, -0.22720885f, 0.37476870f, 0.01005529f, -1.23421800f, -0.13431595f, -1.01843679f, 1.87386346f, -1.68539488f, -1.04942071f, -0.77322137f, 0.53964764f, 0.29278332f, -0.58299130f, -1.56022692f, -0.79441273f, 0.49289709f, 0.44112054f, 1.07305002f, 0.54899335f, 1.13781393f, 0.77809113f, 0.81795985f, 0.16576190f, 0.32552773f, -0.20250474f, 1.46543837f, 0.12731771f, 0.21013761f, -1.34241438f, 0.44267517f, 0.93246883f, 0.08808212f, 0.92653406f, -1.21083558f, 0.17247954f, -0.70557106f, 0.04630012f, 0.48834828f, 0.89634645f, 0.46683592f, -0.29553145f, 0.46363977f, -0.48971879f, -0.88603491f, -0.12333342f, 0.37073737f, 0.92061806f, 0.54675460f, -0.14716248f, 0.75578392f, -0.98173791f, -1.15983224f, -0.58713156f, 0.07950903f, -0.59016788f, 0.41622928f, -0.32474482f, 0.42086437f, 0.23061797f, 0.62596649f, -0.22615278f, -2.14721417f, 1.01685894f, -0.25976995f, 0.00739352f, -1.31597066f, 0.39005190f, -1.09549701f, 1.68375242f, 0.43331525f, -0.37124026f, 0.22255214f, 0.59654880f, -0.73840386f, -1.20048976f, 0.12226126f, 0.12997478f, 1.04826224f, 0.03894836f, -0.36289826f, 1.14466560f, -1.18198848f, -0.03713558f, 0.67677927f, -0.42329931f, -0.89409167f, -0.77874780f, 0.58438253f, -0.35176343f, -1.53329861f, -0.02995299f, -0.40145162f, -1.51052392f, 0.09194464f, -1.13275242f, -0.61983156f, -0.40004560f, -0.19893464f, 0.22134103f, -0.03903082f, 1.14894116f, -0.03476744f, 0.22520730f, -0.55851930f, 0.76650429f, -0.57863152f, -1.34161711f, -0.31498179f, -1.19411755f, 1.70044947f, -0.17428267f, -0.35983825f, -0.42613637f, 0.58165723f, -0.77866900f, -1.59727287f, -0.61723864f, 1.51078022f, 0.32971445f, -0.86441469f, 0.60552609f, 0.00208178f, -0.47096625f, -1.10479307f, -1.21652532f, -0.08211990f, -1.43739200f, -1.31684434f, 0.43312529f, -0.76822090f, 1.88128507f, -0.02179282f, 1.04971325f, -1.55004108f, 1.25337446f, 0.11203052f, -1.16048300f, 1.59467411f, -1.29469275f, 1.14019871f, 1.20021439f, 1.84098923f, 0.05004879f, 0.73529941f, 2.05272865f, -0.13080600f, -0.08436690f, -1.17919350f, -0.66256678f, -0.36727047f, 0.73840511f, 1.22293818f, -0.00206342f, -0.29839504f, -0.00618613f, 1.04213119f, 1.21176076f, -0.62886089f, -0.02589060f, 0.96009409f, -0.64478731f, -1.16516542f, 0.57528079f, 1.04294407f, -0.09774588f, 0.45935291f, 1.03263175f, 1.00633478f, -1.82209253f, -0.18035053f, -0.28302726f, -0.83813244f, 0.57593471f, -0.03807700f, 1.60498738f, 0.16530658f, -1.43083501f, 2.10824299f, 0.30279446f, -0.03961089f, -0.38900724f, 1.31272805f, -0.56575215f, 0.57970244f, -0.48305038f, 1.34114623f, 0.21859215f, 0.66399640f, -1.52087069f, -1.30717897f, 0.14394683f, 0.97648209f, -0.71372712f, -1.22574198f, -0.27702177f, 0.04041927f, 0.02442212f, 2.19617033f, -0.48566443f, 0.81463927f, 0.20383844f, 1.17562282f, -0.33829874f, -0.42141283f, -0.96415234f, -2.39141965f, -1.04285860f, -0.23004992f, 0.41186509f, 0.03811268f, 0.36818987f, -0.71099734f, -0.56749570f, 0.18486284f, -0.44530040f, 2.14008284f, -0.27467576f, 1.70690107f, -1.40462613f, 0.24697532f, -1.31629777f, -2.20674944f, -0.67868507f, -1.15767133f, -0.64391804f, -1.79037917f, 0.58749497f, -1.58303332f, -0.69021022f, 1.64376318f, -0.95393223f, 1.98415601f, -0.10991055f, 0.02474386f, 0.23683345f, -0.63420391f, -0.57991928f, 0.83028817f, -0.40033704f, 0.19212338f, 0.74640590f, 1.10264432f, -1.65286255f, 0.92683482f, -1.42252541f, -0.74605089f, 2.14535880f, 0.12971123f, -0.47971717f, 1.67546797f, 0.42268261f, 0.22648531f, -0.42369929f, 0.77403021f, -1.31818616f, -0.67143595f, -0.04311426f, 1.64128351f, 0.34776631f, -0.39353722f, -0.42765084f, 0.16170517f, -0.54488391f, -0.38428506f, 0.42097485f, -0.55982012f, -1.74543798f, 1.53704774f, 0.43562424f, -0.30395737f, 0.31846946f, 0.39205357f, 0.57386035f, -1.11912560f, -1.39164317f, -1.04337609f, 0.31629622f, 1.51927638f, 0.88745505f, -0.40445471f, 0.25783861f, 1.88646257f, 0.36509129f, -1.13266826f, -0.45394278f, -0.48400903f, -1.22332740f, 0.38626808f, -1.10049105f, 0.84138852f, 1.27863181f, 0.53942156f, -0.67743856f, -0.03896645f, 1.70393491f, 0.60997570f, 0.43368068f, -0.13338457f, -0.18920666f, -0.29583672f, -1.40738738f, 1.03876019f, 1.71253765f, 2.12821221f, -0.96092403f, 0.93841934f, -0.79030478f, 1.36427641f, -1.39196694f, 0.08514920f, 0.16223004f, 0.71259701f, 0.20150672f, 0.25068361f, -0.99952722f, 1.80129099f, -1.28586197f, -0.64957166f, -0.94813949f, -0.40161121f, 0.31977695f, 0.54932386f, -0.67757767f, 1.88086259f, 0.92337233f, -1.64887333f, 0.44333732f, -0.19468001f, 0.12977587f, 0.21171951f, 0.27679422f, 0.49134475f, -1.44429457f, 1.25617445f, 0.39978400f, 0.99869555f, -1.61617446f, 1.61177349f, 0.70243025f, -0.95748568f, -0.61795151f, -0.77302909f, 0.72967088f, 0.81964350f, -0.71813750f, 0.90140164f, -1.45950246f, -0.79972702f, 0.40875742f, 0.00152073f, -1.74491429f, 1.53776145f, 0.75769204f, -0.22075878f, -0.58385569f, 2.18884754f, 0.33597681f, -1.66265559f, 1.03805876f, -1.55245185f, -0.03582226f, -1.94542754f, -0.76081425f, -0.50471377f, 1.35763168f, -0.39631784f, -0.17134467f, -0.82220149f, -0.41021580f, -0.00940776f, -0.80176353f, -0.19816744f, 1.22061026f, -0.14486519f, -0.71727395f, -0.65721530f, 0.47020102f, -0.70403302f, -0.94795334f, 1.79884899f, 0.07779162f, -1.50615680f, 0.04140327f, -0.22001404f, 0.63735324f, 0.79237640f, -2.25412822f, -0.52519119f, -0.87280381f, -0.07100742f, -0.94734806f, -0.12286110f, -0.13623615f, -0.42595413f, 0.17547913f, -0.81707209f, 0.36855817f, -1.68186557f, 0.19312963f, -0.66249490f, -0.98283452f, -0.33314428f, 0.40918943f, 0.88268638f, -0.05390308f, -0.22440539f, -0.15879378f, -0.34859571f, -0.01013108f, -0.30005428f, -1.19408464f, 0.21789688f, -1.07769871f, 0.81475031f, -0.69555300f, 2.35201311f, -0.40362412f, 0.93497628f, 1.13343573f, 0.92343372f, 0.26987928f, 0.46123627f, 0.22577702f, 1.26289701f, -0.45956740f, 0.55994868f, -0.58410591f, 0.13304594f, -0.25806463f, 0.49044946f, -0.82065403f, -3.06672239f, -0.27774641f, 0.68504512f, -0.21386372f, 1.11427057f, -0.73201770f, 0.51655543f, 1.77261138f, 0.72081727f, 0.11116749f, 0.16637769f, -0.74987584f, 0.66579849f, -0.75808716f, 0.20678560f, -0.67698354f, -0.82141948f, 0.61008269f, 0.66520184f, 0.44894725f, 0.73015076f, -1.52517414f, 0.11714164f, 1.90452611f, -1.30355322f, 0.12144456f, 1.18547559f, -0.07349755f, -2.28061509f, 0.83522540f, 0.78438890f, 2.19334102f, 0.90305614f, -0.59345531f, 0.77925014f, 1.32338643f, 0.14068902f, 1.19032264f, 0.20666829f, -0.76595837f, 0.74967057f, 2.86965609f, 0.55690205f, -1.72530472f, -0.83317834f, -0.85842621f, -0.29678273f, 1.80955839f, -0.70496303f, 1.19106734f, -0.92985237f, -1.00617313f, -0.56049556f, -0.29382578f, -2.04022193f, -1.95356870f, -0.42553005f, -0.33369407f, 1.02115977f, -1.45769477f, -0.67720300f, 0.53819913f, 1.57643425f, -0.47015440f, -1.47861958f, -0.00545934f, -0.97836047f, 0.42680529f, 1.56110144f, -1.49487829f, -0.65198445f, 0.22720462f, 1.83036661f, -0.47099793f, -0.09915133f, 0.14923312f, -1.16313052f, 0.67798084f, -1.63665557f, -0.38220280f, 0.01719763f, 0.30041245f, 0.43148938f, -0.44021657f, -1.25734651f, 0.02465564f, -1.00845659f, -0.28574651f, 0.01367745f, 0.77253437f, -0.99399441f, 0.61445391f, 0.18343423f, -0.50997210f, 0.41359940f, 0.77279282f, 0.83511519f, 0.27929801f, 0.70800692f, -0.20278299f, 1.57884383f, 0.22650529f, 0.43347472f, 0.74003208f, -0.71401161f, -0.69829476f, -1.56766701f, -0.99254119f, 1.27301061f, 2.73726511f, 0.66089469f, -1.95778012f, -1.24642098f, -0.63579029f, -1.63168180f, -0.66980726f, 0.81933254f, 0.61866677f, 1.40594471f, 0.05158535f, 0.00196500f, -0.24592508f, -0.50780547f, -0.83905292f, -0.10748957f, 0.04490763f, 0.27769178f, -0.23227681f, 0.82108080f, 0.03562285f, 0.95483875f, -1.49897683f, 0.67809856f, 0.35497451f, -0.44021592f, -1.67361462f, -0.88895375f, 1.44293678f, -0.85046643f, -0.46437624f, -1.87252641f, 0.26775804f, -0.24535774f, 0.73365933f, 0.52253938f, 0.27947086f, -0.58796054f, 0.59045380f, 1.93476331f, -0.46775359f, 0.25238225f, -1.26601815f, -0.13324316f, -0.71454948f, -0.21610366f, -1.49586582f, 1.04903507f, 0.22208478f, 0.25512528f, -0.46157327f, -0.41319233f, -0.63846964f, -0.25100923f, 0.81277549f, -0.26959971f, 0.88737756f, 1.24578953f, -0.91121447f, -1.05756927f, 0.44390878f, 0.16672316f, -1.22941923f, 0.89547867f, -1.50212002f, -1.69620168f, 0.53339505f, -0.23656729f, -1.69879091f, 0.01510374f, 0.08315694f, -0.73196459f, -1.60263407f, -1.07601058f, -0.76389569f, -1.65307498f, -0.61484390f, -0.43546933f, 0.71318507f, -0.16273083f, 0.64122051f, -0.15406294f, 1.17673671f, -0.91240519f, 0.71091145f, 2.40497613f, 1.26343656f, 0.71469337f, 0.20705548f, 0.81776261f, 0.36253929f, -1.92106628f, -0.09300470f, -0.36648872f, 1.27732766f, -0.39180157f, -0.61186749f, -1.03455031f, -0.25079829f, -0.61479062f, -1.07094336f, 0.82218504f, 0.89934880f, 0.41308978f, -0.59968555f, 0.37682834f, -1.77388155f, 0.00294951f, -0.66145372f, -0.50789726f, -0.85123241f, -0.89909405f, -1.89454281f, -0.56692821f, 1.52272677f, -0.11961794f, 0.27843913f, -0.60582250f, 1.01871169f, -0.36098275f, -0.12242325f, -0.67375034f, -0.11204147f, -2.62773919f, -0.95901299f, 0.14040214f, 1.32364666f, -1.35099924f, -0.11077739f, -0.79319423f, 0.75949597f, -0.25485823f, -0.90959758f, -0.42373934f, -1.29850340f, 0.85699379f, -1.11882365f, 0.63470817f, 0.49696380f, -0.07983235f, -0.23903450f, -0.22618714f, -0.12117998f, -0.09442677f, 1.55589819f, -0.11996678f, -1.72700179f, 0.54683149f, -0.40804827f, -0.50099218f, 0.34596699f, -1.81841791f, 0.06385052f, 0.84428120f, 0.69901514f, 1.94559097f, 0.43251973f, 0.16794942f, 1.82829034f, 1.70959795f, 0.36130908f, -0.94608402f, -0.53498030f, 0.47781768f, -0.24203247f, 1.25065851f, 0.51788396f, -2.09381890f, 0.72973937f, 0.03281829f, 0.58632666f, 1.85737121f, -0.49569523f, 0.45921183f, 1.87173629f, 0.22803484f, 1.66433418f, -1.05872321f, -1.13663685f, 0.12397861f, -0.65112090f, 0.98152941f, 0.83739656f, -0.18783289f, 1.84249437f, -0.90706986f, -0.80824369f, -1.23854923f, -0.86488134f, -1.02627063f, 0.10976455f, -0.61403006f, 1.27554715f, 0.14653525f, -0.03953953f, -0.08512071f, -1.30043304f, -0.02566035f, 0.12054887f, 0.00282162f, 0.48921332f, -1.74398839f, 1.44554436f, -1.35854721f, 0.69256759f, 0.34101671f, 2.50045252f, 0.49121150f, -0.27115449f, 0.93974596f, 0.26258010f, 0.27151433f, -0.87214381f, -0.92580765f, -1.03269923f, 0.20615758f, -0.37822601f, 0.58983004f, 0.16426525f, 0.68218285f, 1.98158526f, 0.47492698f, 0.54224718f, 1.28722692f, -1.76915324f, -1.11240053f, 0.77428484f, 0.27184650f, 2.22473478f, -0.05574624f, 0.39976570f, -0.43911108f, 0.52805597f, 0.17340177f, 1.36057591f, -0.35004014f, 1.72787797f, 0.68357420f, 1.25532615f, -0.56752264f, 0.51840127f, -0.21237844f, -0.58821255f, -0.85278064f, 1.90179110f, -0.67447448f, -0.36831430f, -0.22930753f, 0.98231596f, -0.07011599f, -0.08560387f, 0.05998110f, -0.02481356f, -0.57335132f, -0.44288307f, -0.24468307f, 0.53321087f, 1.19609559f, 0.10664973f, 0.24379487f, 0.93687552f, 0.93615580f, 1.74319768f, -0.68310338f, 1.32163060f, 0.61918712f, -0.76501870f, -0.54549301f, 1.74077415f, -0.69977754f, -0.66880983f, -1.15981388f, 0.81571609f, 0.53788543f, 0.47898352f, -0.02484704f, -1.64646924f, -0.69822907f, 0.27020717f, 0.05027051f, 1.75149667f, 0.01548872f, 0.32615909f, 2.55151844f, -1.29172051f, -0.36133784f, 0.98637396f, 0.14009331f, -0.50038946f, -0.92230296f, 0.17307127f, 1.05361068f, -1.46784890f, 2.38960409f, 1.19413340f, -1.33349669f, 1.59141159f, -0.71811068f, 1.22429430f, 1.26947939f, 1.08177102f, -1.18138707f, -0.72775704f, 0.17282635f, -0.40554270f, -0.40341887f, 0.46564049f, -1.02069795f, -0.07653128f, -0.13979210f, -0.31195050f, -1.72042310f, 1.37131393f, 0.63849634f, 0.75561279f, 1.81152904f, 0.26686314f, 1.32796574f, 0.56100166f, 0.70058894f, -0.88962644f, -0.04360984f, -0.88249093f, 0.24311203f, 0.50410056f, -2.22567797f, 0.94520348f, -2.12467694f, 0.47282359f, -0.71379906f, -0.09857135f, 0.62374717f, 1.37182784f, 0.73380554f, 0.59745449f, 2.80427694f, 0.67253572f, 1.65335357f, 1.69891667f, 1.34585941f, -0.79989213f, 1.44980943f, -0.52013642f, -0.46971673f, -1.50070012f, -0.25687039f, -0.56916732f, 0.71065760f, -1.31996286f, 0.96031237f, 0.13929774f, 1.49679291f, -0.05966444f, -0.58674580f, -0.08278833f, -0.93390942f, 0.42415768f, -1.77889526f, 0.75336021f, -0.72699982f, -0.82880586f, 0.63955617f, 0.42771208f, -0.42366457f, -0.91581815f, 0.94750947f, 0.43123913f, -0.99053741f, 0.70470595f, -1.16662264f, 1.14847183f, -0.83885664f, 0.46714026f, -2.27748466f, -1.23656678f, 0.14695056f, -0.33159894f, -0.52553117f, -0.04391259f, -0.29630372f, 0.25949728f, 0.96991086f, -0.37714824f, -0.28251833f, 0.16106486f, 1.38844633f, -0.18713553f, -1.30708838f, 0.48490265f, 0.29553881f, -0.45505449f, 0.83341682f, 0.87346369f, -0.63516861f, 0.66063565f, 0.93892503f, -2.73996735f, -0.81515318f, -0.91458052f, 0.00978268f, 0.43472794f, -0.08090764f, 1.37249672f, 0.76722521f, -1.19154143f, 0.22046764f, 0.34916410f, 0.51383299f, -0.56379753f, -2.49949312f, -0.74207872f, -0.68400806f, -0.09663232f, -0.07199454f, -1.05562651f, -0.75028551f, -0.87253797f, 0.69039482f, 0.45923674f, -1.27515161f, -0.04555376f, -1.41501272f, -0.83773375f, -0.74807298f, 1.36646152f, 0.06317432f, -1.32559633f, 1.89092779f, 1.24883330f, -1.03608561f, 1.08677161f, -0.99629849f, -0.69947034f, -0.85716367f, -0.07947286f, -0.25485426f, -0.19732477f, 1.64581251f, 1.04618108f, 1.87186897f, -0.18198362f, -0.83807969f, 0.70462501f, -3.18930101f, 0.74610996f, -0.60935193f, -0.49383929f, -2.88986492f, 0.51707613f, 1.04620326f, 1.09837818f, -1.19840038f, -0.10391295f, -0.20789115f, -1.51052022f, -0.31087330f, 0.22411564f, -1.30506921f, -1.52000105f, -1.51593041f, 1.04321992f, 0.97611690f, 0.90424490f, 1.83324766f, -0.08682299f, 0.47035542f, 1.70865905f, -0.31108001f, 0.04115159f, -1.36352801f, -0.90797836f, 0.32128647f, 0.66191489f, 0.08681208f, 0.14993365f, 0.47110486f, -0.31522670f, -0.38906571f, -0.08876022f, -0.13106902f, 2.25685239f, -0.62211353f, -1.68553007f, -0.23707703f, 0.69236159f, -0.46686995f, -0.27520603f, 0.26619941f, 1.48525345f, 1.61278927f, 0.49452963f, 1.20846486f, -1.11853909f, -0.30010033f, -0.75471467f, -1.69959772f, -0.52042168f, -0.43881389f, -1.45240712f, 1.02122891f, 1.73639011f, -0.03813924f, -0.22239220f, 0.15797073f, -0.64418089f, -0.60228932f, -0.83248150f, -0.02042520f, 0.38137484f, 0.86056453f, 0.06410559f, -0.62785137f, -0.49916875f, -2.53796315f, -0.79168582f, -0.69197005f, -0.77175534f, -0.28669405f, -0.79764080f, 0.97218460f, -0.10351621f, -0.52759898f, 1.02840185f, 1.16363287f, 0.08351815f, -0.61088538f, 0.59944046f, 1.54409397f, -1.39842033f, 0.27917057f, -0.27146137f, 1.46310735f, 0.03626106f, 0.15038440f, -0.07894899f, -1.42527366f, 1.69641745f, 1.48384345f, -0.43328866f, -0.54252565f, -0.94416499f, 1.54436302f, -0.81367069f, -1.67925239f, -0.17525831f, 0.27891046f, -0.69066733f, 0.89911050f, 0.11606655f, 0.67450327f, 0.41538724f, 0.90886223f, 1.19786549f, 0.85810721f, 1.32862210f, -0.83469814f, -1.09682298f, 0.88092703f, -0.97478902f, -0.11664717f, -0.07929394f, -0.69581884f, -0.16928329f, -0.70731819f, -0.40485084f, -0.28954300f, 0.52882415f, 0.38769314f, -1.38704026f, 1.15099049f, -0.43566978f, 0.34459323f, 0.49520254f, 1.11130333f, 0.28783718f, -0.53783375f, -1.63577271f, 1.02222812f, 0.86302060f, 0.48346213f, 0.46627176f, -1.30133855f, -1.48477137f, 0.31219670f, -1.21498191f, 0.89838904f, 0.87186617f, -0.39968935f, 0.34930915f, -0.32909471f, -1.39364409f, 2.13006306f, 0.33270469f, 0.00215986f, 0.97776711f, 0.24908836f, 1.56164885f, 0.45157790f, -1.55970144f, 0.27677536f, 0.07662498f, -0.08262251f, -0.17658773f, 0.65820259f, 2.01052690f, -1.71946216f, 0.84686053f, -1.23594892f, 1.40792072f, -1.47772563f, -0.36132276f, -0.50405115f, 0.09009213f, 0.81659186f, 1.85574234f, -0.64974433f, 0.63352364f, 1.01766217f, -1.54804432f, -0.42570522f, -0.24763709f, 0.72822112f, -0.93733686f, 0.68087620f, -1.40644944f, 0.48672482f, 0.09725539f, -0.64416331f, -0.95747960f, 0.36771363f, 0.39155054f, -0.71790671f, -2.17222738f, -0.08655047f, -0.97842115f, -0.22991380f, 0.52029115f, -1.42072022f, 0.29576331f, 0.32391560f, -1.00823236f, 1.67909145f, 1.16841447f, -0.32307062f, 0.15756166f, -0.97590631f, -0.39429301f, -0.03583352f, 0.17554663f, 0.57961231f, -0.46873134f, -0.23343173f, -0.85060924f, 1.71745574f, -0.04658702f, 0.63088381f, -0.67581934f, -1.53171062f, -1.58800113f, -1.17987096f, -1.16737640f, -0.87544650f, -1.17138922f, 0.38979119f, -2.39369726f, -1.34747124f, 0.58450359f, 0.87791806f, -0.04459394f, 0.97995293f, -0.10354915f, 0.65324986f, -0.17833626f, -0.85849386f, -0.42063358f, 0.19708554f, 0.10255250f, -0.59539181f, 0.86194044f, 1.68610668f, 0.55275291f, -0.43127069f, -0.04218780f, -0.08466262f, 0.31236625f, -0.92824298f, -0.09879152f, 0.32358822f, 1.04045570f, 0.35617545f, 0.09059231f, 1.19069445f, 1.96978688f, 0.63561743f, 0.15030998f, -0.29879019f, 0.22774190f, -1.01608860f, 1.03605175f, 0.47804731f, -0.30450734f, -0.61382371f, 0.45390254f, -1.93547988f, 2.01267338f, 0.52447683f, 0.18379784f, 1.11913633f, -1.24273467f, 0.15803322f, 1.72184098f, -0.79349059f, 0.10258614f, -1.53445125f, 0.02630571f, 0.81649125f, 0.91089755f, -1.12968338f, 1.04016411f, 0.28999722f, 0.74863863f, -0.61388236f, 0.01665530f, 1.43592548f, 0.68138391f, 0.11963340f, -1.26123953f, 1.36340797f, 0.25696915f, -0.58877039f, 1.42209792f, 0.55563360f, -1.33329606f, 1.84695840f, 0.88433737f, 1.04359078f, 0.18906727f, -0.03448994f, 1.17944050f, 0.86783957f, 0.44934425f, -0.77892244f, -1.76232874f, -1.01689589f, 0.78943914f, 0.92141974f, -1.00187087f, -0.13809921f, -0.90222073f, 1.10094714f, -0.13657950f, -0.44349849f, -1.61441302f, 1.05724919f, 1.50337231f, -0.05785890f, -0.76958144f, -0.51498759f, 0.69227600f, -0.37975949f, 1.31949317f, 0.82049531f, 0.32868597f, -0.31557772f, -0.75534385f, 1.27303052f, 0.43453619f, 0.11296938f, 1.18182182f, 2.23387384f, -0.86412978f, -0.01599468f, -0.70869064f, -0.09221385f, -1.23729551f, 0.79490280f, 0.03522846f, -0.95069039f, -1.73461652f, 0.72329187f, 1.40385795f, -0.11585230f, -0.78033113f, 0.07491048f, -1.12873089f, 0.18476245f, 0.57568848f, -0.28792691f, 1.35411644f, -0.76956165f, 0.29571572f, 1.03178787f, -0.38780826f, 0.31680650f, 0.69368076f, -1.23856580f, -0.49848995f, 0.14766994f, 1.02625990f, 3.03858209f, -0.51030380f, 0.96796870f, 1.35078156f, -1.07729447f, 0.84322494f, 0.54886484f, 1.31453705f, -0.45792100f, 0.31196272f, -0.15701357f, 0.83586836f, -0.74952888f, -1.17432022f, -0.31002575f, -1.02149463f, -0.36117774f, -1.22079086f, 0.03532525f, 0.00555908f, -0.45891216f, 0.29636297f, -0.68272704f, 0.41257843f, 0.37988129f, 0.01747893f, 0.82739186f, 1.52292180f, -0.79456621f, 2.20275712f, 2.13212132f, -0.81393015f, -1.15712392f, 0.22488308f, 0.62776327f, -0.85444915f, 0.44017896f, 0.05863331f, -0.83198178f, 0.93063420f, -0.16121253f, 0.12382501f, -0.37826315f, 0.93118382f, 0.19507533f, -0.58595538f, 1.46994352f, 0.13170272f, -0.70031989f, -0.12820166f, 0.30487457f, 0.84148771f, -0.68807501f, 0.21187615f, -0.67030680f, -1.79136002f, 0.70810199f, -1.20959783f, -0.08468831f, -0.06317700f, 1.35527098f, -0.47018668f, -0.91693246f, 0.14818805f, -0.05405350f, 1.16875637f, -0.17363262f, -1.61833882f, -0.32934523f, -0.38346377f, -0.62702698f, 0.34135151f, 0.48015586f, -0.65263331f, -0.04689486f, 0.01156854f, 0.37580970f, -0.16174591f, 0.59627324f, 0.24351901f, -0.87983090f, 1.57049024f, 1.25836349f, -0.41464049f, -0.62279183f, 0.09693756f, -0.23850618f, -0.49007827f, 0.22298151f, 0.10914832f, -0.35192192f, -1.27221346f, 1.10203624f, -0.86399704f, -0.47319838f, -0.77105570f, -1.68624854f, 0.81198281f, 0.82534081f, 0.75654501f, 1.47631240f, -0.61000234f, -0.58933264f, 0.54822850f, -1.22829592f, 0.11107657f, 0.56449169f, 1.50693524f, -0.59280968f, -0.64286685f, -0.20120731f, 0.27184448f, 1.55500400f, -0.48919386f, 1.04044867f, -0.87048137f, -0.40569979f, 0.21908638f, -0.51829034f, -1.48748124f, 0.02990401f, 1.83462536f, 0.29885170f, 1.32370698f, -1.30129600f, 2.43271399f, 0.22967771f, -1.13014007f, 0.95529765f, -0.83325785f, 0.43633386f, 0.85774118f, 0.78160155f, 0.58583075f, 1.18906367f, -1.54354560f, -0.68320692f, 0.01900371f, -0.79777133f, 0.12851712f, 1.10176420f, 0.79418170f, -1.41154039f, 0.36929929f, 1.12176800f, 1.23849642f, -0.89377707f, 1.01390159f, -0.50889206f, -1.12554002f, 0.17932732f, 0.48949540f, -0.54235244f, -0.28146735f, -1.39125514f, 0.13309635f, -1.12864995f, -1.29901242f, -0.04266220f, -1.98028529f, -1.34869373f, 0.00038156f, -0.92473024f, 1.48010647f, -0.02754467f, -0.26030368f, 0.93083733f, 0.27946711f, 0.64052200f, -0.04220961f, 1.25002527f, -1.07923257f, 0.19048618f, 0.08900311f, -0.40813437f, -0.73068553f, 0.52122378f, 0.68990833f, -0.38749605f, -1.09269309f, -1.63480806f, 1.01789618f, -0.61596102f, 0.81049860f, 1.30838764f, -1.49213874f, -0.77916288f, -0.72660202f, -0.92013240f, -1.61726642f, -0.11527207f, 0.35143322f, -1.11646879f, -1.45525432f, -0.82892823f, 0.15512508f, 1.01891017f, 1.40162635f, 1.02494884f, 0.33882582f, -0.78747398f, -0.26009330f, -0.38519114f, 0.79247451f, 0.02065756f, -0.48030257f, 1.01167107f, -1.74057114f, -0.84549171f, -0.15337363f, -1.92544484f, 1.01270044f, 0.00762185f, -0.16405612f, 1.61778915f, 0.93316060f, -0.68960994f, -1.13214970f, -0.94695878f, -0.28418848f, 0.17102109f, -0.08787476f, -1.83799696f, -0.13761258f, -0.18652774f, 1.46456254f, 0.34169790f, -0.40697145f, 1.49663997f, -0.99555492f, -0.67775637f, -0.51951116f, 1.35157657f, -0.27099034f, -0.46987835f, 2.28101230f, 0.59104478f, 0.75010139f, 1.01472175f, 0.25741309f, -0.56074983f, 1.12267506f, 0.35336846f, 0.61733276f, -1.63976014f, -0.17700450f, -0.25093642f, -0.75599891f, 2.10956192f, 0.95155340f, 0.72049862f, 0.50492924f, 0.62067389f, 2.08688402f, -0.73604703f, 0.63383341f, -0.53528428f, -2.11538506f, -0.98173052f, 0.59560484f, -0.26205051f, -0.91948050f, 0.00593397f, -0.11734286f, -1.41261208f, -0.83611172f, -0.27682739f, -0.20619918f, -0.36557615f, 0.77194935f, 1.67695415f, -1.39265156f, 0.04892010f, -0.37773246f, 0.16124558f, -0.18348448f, -1.38248885f, 0.58459854f, 0.65064198f, 1.11349559f, 0.36708066f, -0.15471332f, 0.14208725f, -2.06860566f, 0.29629150f, 0.93084633f, -0.47215626f, 0.60208917f, 0.95415461f, 1.03390312f, -0.03639749f, -0.23988228f, 1.27037442f, 0.95133096f, 0.33187470f, -0.34527761f, 0.22134073f, 1.01799667f, -0.81475645f, -1.18869019f, 0.23314142f, 0.25180560f, -1.23762786f, 1.25283313f, 0.16980635f, 0.40740708f, 0.59256923f, 0.16274920f, -0.69713289f, -0.16444311f, -2.41602516f, 0.37952334f, -0.05604568f, -0.23772651f, 0.20581599f, -0.54303211f, 1.71877348f, 0.83602583f, -0.32586128f, 0.73609394f, -1.73640239f, 0.07249248f, 0.31248692f, 1.77627432f, 0.97660398f, -0.42095289f, -0.18750280f, -0.84246057f, 0.29762223f, 1.87054563f, -1.46980762f, -0.45306337f, 1.52366042f, 1.39061129f, -0.04980387f, -0.55382830f, -0.96987218f, -0.06910808f, -0.41276473f, -0.83891344f, -0.92597574f, 0.60252470f, 0.21938549f, -0.04451685f, -1.00330937f, -0.36955237f, -1.52876902f, 0.27296364f, -1.96721256f, 0.05291027f, -0.91540521f, 0.48990685f, -1.99560380f, -0.68551093f, -0.14532298f, -1.56881595f, -0.08319287f, 0.31003201f, -1.42829597f, -0.61810297f, -0.03581250f, 0.77747720f, 1.25297558f, -1.36239243f, -1.13274276f, -0.35045877f, -2.34157228f, 0.04515179f, -0.83044821f, 1.81353962f, -1.36855912f, 0.39704823f, 0.16665934f, -0.16654585f, 1.17806077f, 1.00086153f, -1.25474250f, -1.46876431f, 1.18021631f, -0.32257929f, 2.12062597f, 0.86819613f, -1.18048275f, -1.69747460f, -0.74092305f, 0.05086798f, 1.15339577f, 1.32972670f, 0.27247882f, 0.98499072f, 2.35597157f, 0.30179837f, -0.66633248f, 0.13794266f, -0.22753908f, -0.22868259f, -1.81792033f, 0.50151759f, -0.79408127f, -1.05343878f, 0.45727381f, 0.84800923f, -1.73605800f, -0.02032863f, 1.82778001f, 1.41025102f, -0.81715560f, 0.25888795f, -0.25075480f, 0.66256499f, 0.11993053f, 1.81336939f, -0.06345166f, -1.49658346f, 0.07531686f, 0.96972889f, 0.87405980f, 0.75830793f, -0.13497087f, -2.45855975f, -0.65984958f, 0.93919373f, -0.97305542f, 0.73477978f, 1.04337513f, -1.22712576f, -0.46385625f, -1.20876372f, -0.82760453f, 0.01455977f, -1.05089867f, -0.02801843f, 0.60899758f, -0.82052249f, -1.48932517f, -0.98073828f, -0.19311285f, -0.25602359f, 0.50351876f, -1.24557400f, -0.82138073f, -1.45966852f, 0.44991320f, -0.75550151f, -0.98550314f, -1.21418869f, -1.15771639f, -1.72192061f, -0.39616469f, -0.55566746f, -1.31880891f, -0.08843257f, 1.00422776f, 0.35846478f, 0.46060917f, 0.77326930f, 1.60129988f, -1.85124147f, -0.30582917f, 1.30227256f, 1.81890345f, -0.44084981f, 0.25315762f, 0.70259613f, -0.94882858f, 1.97040296f, 0.71473581f, -0.68193883f, -0.36290962f, 1.16348684f, 0.15418798f, 1.07806778f, 0.40554729f, 0.10280909f, -1.06474805f, 0.64398485f, -0.63568884f, -0.06108581f, -1.03290677f, 1.02834034f, 1.15284693f, 0.14046004f, 1.86630619f, 0.46804786f, -0.68397558f, 1.60733378f, -1.64890087f, -1.03819239f, -1.19212389f, -0.78382361f, 0.03925850f, 1.52259934f, 0.09540676f, -0.21220762f, 0.55955195f, -0.39845437f, -2.14541650f, 0.49337825f, -0.68574250f, 0.74040270f, 0.50783634f, -1.60461199f, -1.26806450f, -0.12652303f, -0.83992827f, -0.15524681f, 0.40098447f, 0.23392735f, -0.23262636f, 0.06525709f, -0.35994548f, -1.08432877f, -0.21395946f, -0.78357452f, -0.57157278f, 0.71407390f, 0.86596155f, -1.13723528f, 0.13460183f, -1.20881450f, 0.71018457f, 0.68943661f, -0.70428050f, 0.64600736f, 0.01990297f, -0.10575775f, -0.80263519f, 0.10618331f, 0.08865548f, 1.51651669f, 0.60851854f, 1.15161908f, 1.04919207f, 1.18359745f, -0.04352076f, -0.83643389f, -0.07922365f, 0.10597949f, -1.34984851f, -1.91319740f, 0.71585363f, -2.10845160f, 0.64385056f, -0.54551518f, -1.02039802f, -1.62510490f, 1.65401149f, -0.42711899f, 0.07970079f, -0.21404363f, 0.30498922f, 1.07942021f, 0.63995659f, -1.82114816f, 0.56396323f, 1.07084870f, -2.00350380f, 0.53339815f, 0.18500003f, 1.15034151f, -0.21436051f, -0.99986565f, -0.58812016f, -0.07247020f, 0.78910017f, 0.48839527f, 0.98795873f, 0.10357288f, -0.05604928f, 0.38977858f, 0.73745090f, 1.40838420f, 0.25967824f, 0.23588051f, -0.03451392f, 1.04897523f, -1.77121758f, 2.35625434f, -0.67086869f, -0.84005541f, -0.85940343f, -1.04449213f, -0.65917015f, -0.78713167f, -0.95910054f, 0.38597879f, -0.31879017f, -0.86260867f, -1.08593106f, 0.02802678f, 0.99484950f, -0.55113328f, 2.60936737f, -0.03388772f, -0.47583574f, -0.14021793f, 0.99019170f, -1.22431207f, 0.78734446f, -1.77037835f, 0.15018673f, 0.36423206f, 1.36447549f, -1.61007094f, 0.51875496f, -1.60788095f, -1.73557448f, -0.41414359f, -0.93710536f, 0.38715765f, 0.04243837f, -1.59682858f, -1.10728157f, 1.88292623f, -1.01428258f, 0.01074958f, -1.88169158f, -0.31616244f, 0.45334938f, 1.12449574f, -1.16699445f, -1.59505820f, 0.04126552f, -0.89016622f, 0.45838884f, 0.71463561f, 0.14563711f, 0.30694655f, 0.67193079f, 0.61429602f, 1.00201404f, -0.49295208f, 0.05997690f, 0.99491668f, -0.73801446f, -1.17185295f, 0.94778723f, 0.36106884f, -0.43561545f, 0.04102699f, 0.52626407f, 0.08442099f, -1.57626402f, 1.56855237f, -1.65396678f, 1.74014664f, -0.38219589f, 0.39305371f, -0.31705827f, -1.15742850f, 0.11669596f, 0.54043210f, -0.52270615f, -0.13375773f, 0.68094701f, -1.84134769f, -1.49383473f, 0.14632171f, -0.54607725f, -1.20867658f, -1.28439069f, -1.81734920f, 1.54257309f, 0.78347659f, -0.24049839f, 1.69973648f, 0.99825776f, 0.99971974f, -0.26055810f, 0.34143049f, -0.44862366f, 0.11253342f, -0.60932243f, 0.70383030f, -1.87318194f, 0.21953633f, 0.82791799f, 1.64545465f, -0.42693698f, -0.64897031f, -0.97996652f, -1.06616282f, 0.52939081f, -0.12541170f, -0.57480675f, 0.73600835f, 0.35711968f, -0.03528263f, 0.79997194f, 0.55742902f, -0.28909785f, 0.64331138f, -1.79893720f, 1.01572442f, 0.27111965f, -0.51778597f, 0.12906317f, 0.76148927f, 1.51315522f, 0.41101140f, 0.38008851f, 0.66759896f, -0.13804778f, 0.64854795f, 1.73474562f, 0.75999504f, -0.73411214f, -0.05406699f, 1.35664344f, -0.25298578f, -0.12696666f, -0.42628938f, 0.61129904f, 1.55259824f, -0.05820796f, -0.38598019f, -0.87325627f, -0.55066222f, -1.24557889f, -0.26509118f, -0.32103062f, 1.14031804f, -0.75985742f, 0.70659167f, -1.15016067f, 1.24906838f, 0.90396994f, -0.16241251f, 0.43682271f, -1.42695689f, 0.47134697f, -1.66143429f, 0.08698819f, -1.00775325f, -2.24129725f, -1.04226267f, -0.98537570f, -0.89938259f, -1.80710697f, -1.22866321f, 0.78125423f, 1.55150509f, 0.46235040f, 0.18444096f, 0.19313288f, -2.20686269f, -0.40341458f, 0.50321484f, 0.47339424f, -0.81383848f, -0.21972439f, 0.66612029f, 0.60239881f, 1.20443010f, 0.70015103f, 0.30632916f, 0.01489905f, 0.68129027f, -0.89645082f, -2.68969011f, -0.96684915f, 1.66421318f, 0.74333072f, -0.78321886f, 1.60063362f, -1.27524030f, -1.95856726f, 0.47504124f, 0.15398432f, -0.20796098f, -0.13449343f, 0.93458968f, 1.60390890f, 0.21798505f, -0.27035928f, -1.23248971f, -1.25361061f, 1.34666133f, 1.07233441f, 0.88799530f, -1.23687923f, -0.40781614f, -0.11916534f, -0.88050151f, -0.66422415f, -2.61471510f, 0.78276747f, 2.42323995f, -1.70715427f, 0.71550035f, -0.60298312f, 0.70491880f, 0.46175584f, 0.80827898f, -0.45108104f, -0.98219043f, -1.72823501f, 1.73190725f, 0.53906441f, -1.50445580f, -0.59250867f, -0.07239901f, 0.44743437f, -0.13740127f, 1.69935930f, -1.00480616f, -0.58191377f, 0.39853972f, -0.60960841f, -0.45473522f, -0.76396072f, -0.31872150f, 1.74509728f, -0.59950751f, 0.89810580f, -0.81400329f, 1.14280319f, 1.11165059f, -1.31295311f, -1.60784578f, -0.87506992f, -1.13461006f, -2.09486437f, -0.16449419f, -0.37728927f, 0.47595578f, -0.55342919f, -0.17574213f, 2.21499181f, 1.14331865f, -0.14938518f, 0.18935619f, -0.33802557f, 0.52538890f, 0.82673949f, 1.16562462f, 1.24713838f, 0.98890215f, -0.64991701f, 1.49886703f, 1.97769642f, 0.08059916f, -1.60925281f, -1.23822486f, -1.40829837f, 0.51331180f, -0.29928651f, -1.04348791f, -0.39911583f, 0.69380492f, 1.54516888f, 1.22791195f, 2.25008130f, 1.33348894f, -0.21775827f, -0.71937007f, 0.54982573f, 1.70691478f, 0.32459491f, -0.57187974f, -0.21614684f, 1.08274269f, 0.41384646f, 0.24497485f, -1.43703413f, 0.89616930f, 0.82032162f, -0.24598582f, 0.84271127f, -0.81894702f, -0.01828136f, 1.70397091f, 0.39505738f, -0.51221430f, -0.87979966f, 0.10795479f, 0.45194778f, -0.76008922f, 1.23394477f, -0.56798172f, 1.06459570f, -0.44333413f, -2.40399075f, -0.37267187f, 1.42946172f, 0.95734519f, 1.86127949f, -0.15217264f, 1.68742633f, 1.97638428f, -0.44211119f, -0.98393327f, -0.54173928f, -1.72017395f, 0.74697793f, -1.77827263f, -1.92299354f, -0.17189410f, -0.48633271f, -2.21230388f, -0.45906609f, -0.53493047f, 0.37253976f, -0.56951141f, 0.07728028f, 0.03530006f, -1.18123293f, 1.94158125f, -1.55930352f, 0.69334733f, -1.95163214f, -0.95800400f, -0.01804711f, -0.56747472f, -0.99099451f, -1.52853060f, -0.98279524f, -1.67307866f, 0.96121490f, 0.35654056f, 1.74034202f, -1.44633865f, -0.27781928f, 1.79457986f, -0.41029963f, -0.76871634f, 0.36555341f, -0.77664107f, 0.19535238f, -0.76185411f, -0.19828433f, -0.88820636f, 0.63885397f, 0.11346363f, -2.50265074f, 0.16319332f, -1.01288569f, 1.86605489f, 0.89761645f, 1.11795115f, -0.00714116f, -0.89034635f, -0.76447034f, -0.18822117f, -0.48340848f, -0.99788517f, 1.02172959f, -0.39395007f, 0.72566581f, -0.81438208f, -0.71715081f, 0.96243578f, -1.36424279f, -1.13870537f, 1.17602491f, 0.16320205f, 0.71959788f, 1.66669416f, 0.55690295f, -0.28912008f, -1.19219172f, 0.23308393f, -0.37963116f, 0.45347008f, -0.42606446f, 1.30938649f, 1.25128853f, 0.57649273f, 0.34440875f, -0.23893952f, -1.06604803f, 0.31336102f, 0.75727910f, 0.46772480f, -0.37650385f, -0.06036821f, 1.03686309f, 0.46158856f, -1.81028461f, 1.43393028f, 0.85494965f, -2.34685564f, -0.17571987f, -0.45592231f, -1.31190526f, 1.73194158f, -0.11856517f, 0.07041293f, 0.25689471f, -0.56000596f, 2.06649089f, 0.38954756f, 1.36627376f, 0.13905638f, 0.77370811f, 0.43944249f, -0.08798827f, 0.07245751f, -1.30234015f, 0.29710820f, 0.74389762f, 0.11971968f, -0.07381748f, 1.32652700f, 1.34079397f}); auto input2 = NDArrayFactory::create('c', {3, 4, 4, 5}, {0.98114507f, 0.96400015f, 0.58669623f, 0.60073098f, 0.75425418f, 0.44258752f, 0.76373084f, 0.96593234f, 0.34067846f, 0.57962620f, 0.77517051f, 0.97472977f, 0.79237527f, 0.68690428f, 0.21719366f, 0.79959206f, 0.84814187f, 0.22496814f, 0.08646965f, 0.31110474f, 0.79813162f, 0.19661444f, 0.57760099f, 0.72138960f, 0.15244268f, 0.87687051f, 0.11130344f, 0.01087698f, 0.34817841f, 0.54992017f, 0.23443850f, 0.31725614f, 0.59755220f, 0.20364695f, 0.00531392f, 0.23403114f, 0.07442912f, 0.83707647f, 0.89291743f, 0.09044587f, 0.69041462f, 0.29904183f, 0.61904680f, 0.85306847f, 0.34467042f, 0.95839152f, 0.54517124f, 0.29640937f, 0.94855959f, 0.95970016f, 0.94045145f, 0.95510301f, 0.34666505f, 0.34717010f, 0.69245678f, 0.71669175f, 0.59043738f, 0.64924132f, 0.06033522f, 0.60185199f, 0.04690073f, 0.59241154f, 0.40229547f, 0.23002481f, 0.45161195f, 0.73743778f, 0.93209113f, 0.37294358f, 0.50177744f, 0.15072501f, 0.26146917f, 0.05252146f, 0.04758931f, 0.76448288f, 0.85149045f, 0.08840467f, 0.07692576f, 0.33180160f, 0.27241259f, 0.74834620f, 0.56453640f, 0.23057286f, 0.68429752f, 0.11961551f, 0.39045977f, 0.44356094f, 0.77018807f, 0.07984410f, 0.47926806f, 0.26165759f, 0.18606064f, 0.89972877f, 0.17962874f, 0.47273120f, 0.64641705f, 0.61890443f, 0.58730015f, 0.25937832f, 0.35231561f, 0.10243882f, 0.17459193f, 0.95906995f, 0.09227025f, 0.30003223f, 0.41601210f, 0.38269713f, 0.84799751f, 0.59295173f, 0.76277990f, 0.68910424f, 0.37672606f, 0.40675461f, 0.94346058f, 0.91438505f, 0.84728183f, 0.64367667f, 0.74899979f, 0.60570691f, 0.16417363f, 0.68852426f, 0.85486889f, 0.22585792f, 0.86953176f, 0.07465519f, 0.93096301f, 0.38008822f, 0.38752587f, 0.44004038f, 0.13170612f, 0.94541045f, 0.89349973f, 0.69245307f, 0.94978877f, 0.98776658f, 0.79445884f, 0.30607409f, 0.58264961f, 0.37980538f, 0.41810784f, 0.48903038f, 0.51615888f, 0.57682794f, 0.82481897f, 0.78341080f, 0.48446465f, 0.17447931f, 0.71125424f, 0.30263851f, 0.70675352f, 0.03215584f, 0.92381065f, 0.22343694f, 0.08851149f, 0.91402490f, 0.70074717f, 0.30912192f, 0.37723206f, 0.97579397f, 0.23554587f, 0.95939133f, 0.41565709f, 0.01741416f, 0.58362787f, 0.22106662f, 0.89065537f, 0.31900249f, 0.41280911f, 0.67947610f, 0.04545590f, 0.15352812f, 0.85412524f, 0.84933222f, 0.80000225f, 0.93147073f, 0.70094105f, 0.69269875f, 0.95282194f, 0.65913582f, 0.79186874f, 0.59855248f, 0.39707430f, 0.95126239f, 0.15618217f, 0.33446689f, 0.98123758f, 0.84770758f, 0.98081012f, 0.54427413f, 0.18728519f, 0.89792955f, 0.53360126f, 0.72812986f, 0.13307744f, 0.51217443f, 0.66708084f, 0.29416915f, 0.31298995f, 0.39155037f, 0.29288291f, 0.87063305f, 0.61759154f, 0.73723332f, 0.37167635f, 0.82122716f, 0.22937430f, 0.76570536f, 0.47911792f, 0.02826214f, 0.94277323f, 0.59945469f, 0.19042060f, 0.68173155f, 0.82771295f, 0.95649538f, 0.40833101f, 0.90838542f, 0.55245881f, 0.49011012f, 0.36773444f, 0.34513527f, 0.42050683f, 0.16113964f, 0.30969388f, 0.27174174f, 0.12117655f, 0.35270175f, 0.81967867f, 0.63723136f, 0.84309389f, 0.71822576f, 0.84883484f, 0.32306117f, 0.08176457f, 0.56175486f, 0.34892198f, 0.09306929f, 0.85437582f, 0.13925577f, 0.48629188f, 0.29923539f}); @@ -1967,6 +1967,84 @@ TYPED_TEST(TypedConvolutionTests2, pnormpool2d_bp_3) { } +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, upsampling2d_bp_1) { + const int bS=1, iH=2,iW=2, iC=1; + const int factorH=2, factorW=2; + const int isNCHW = 1; // data format, default is NCHW + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, iH*factorH, iW*factorW}); + gradO = 1.; + + auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW}); + expGradI = 4.; + + nd4j::ops::upsampling2d_bp op; + auto results = op.execute({&input, &gradO}, {}, {isNCHW}); + auto* gradI = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, upsampling2d_bp_2) { + + const int bS=1, iH=2,iW=2, iC=1; + const int factorH=2, factorW=2; + const int isNCHW = 0; // data format, default is NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, iH*factorH, iW*factorW, iC}); + gradO = 1.; + + auto expGradI = NDArrayFactory::create('c', {bS, iH, iW, iC}); + expGradI = 4.; + + nd4j::ops::upsampling2d_bp op; + auto results = op.execute({&input, &gradO}, {}, {isNCHW}); + auto* gradI = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, upsampling2d_bp_3) { + + const int bS=1, iH=3,iW=3, iC=2; + const int factorH=2, factorW=2; + const int isNCHW = 1; // data format, default is NCHW + + NDArray input('c', {bS, iC, iH, iW}, nd4j::DataType::FLOAT32); + + NDArray gradO('c', {bS, iC, iH*factorH, iW*factorW}, {0.6793504, 0.35508695, 0.84278935, 0.20031333, 0.7014987, 0.31069338, 0.44793984, + 0.93800974, 0.32667395, 0.15187258, 0.38331753, 0.78212297, 0.1988072, 0.7985636, 0.1632634, 0.14696825, 0.26089668, 0.13505761, + 0.7562093, 0.27545404, 0.36908787, 0.09282647, 0.83649176, 0.26841334, 0.09506222, 0.31279507, 0.13591796, 0.5175439, 0.32870287, + 0.061735712, 0.39643127, 0.248016, 0.5489592, 0.115046196, 0.8143622, 0.7215636, 0.40449402, 0.29908907, 0.4038839, 0.9883108, + 0.022296403, 0.927782, 0.3184157, 0.0685462, 0.28453344, 0.23272, 0.35214192, 0.058909304, 0.7112212, 0.6744568, 0.19694561, 0.6994972, + 0.0743224, 0.42042503, 0.5842631, 0.14957358, 0.44640633, 0.72307247, 0.06448108, 0.48307765, 0.8759956, 0.5698191, 0.4458631, 0.5277549, + 0.016646361, 0.753678, 0.14063567, 0.7541292, 0.16193217, 0.7750374, 0.3326449, 0.11739397}, nd4j::DataType::FLOAT32); + + NDArray expGradI('c', {bS, iC, iH, iW}, {2.4203868, 1.5216494, 2.1776323, 2.0290341, 0.772146, 1.5008594, 1.0523045, 1.3174672, 1.9263644, + 1.090545, 1.9094483, 1.3611296, 2.1195147, 2.0659215, 1.0423062, 2.3405795, 1.9105877, 1.2203633}, nd4j::DataType::FLOAT32); + + nd4j::ops::upsampling2d_bp op; + auto results = op.execute({&input, &gradO}, {}, {isNCHW}); + auto* gradI = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + delete results; +} #endif //LIBND4J_CONVOLUTIONTESTS2_H \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index 04f9a7570..d0c597cc5 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -640,40 +640,6 @@ TEST_F(DeclarableOpsTests1, ClipByValue1) { delete block; } -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests1, MergeMaxTest1) { - - auto x = NDArrayFactory::create_('c', {5, 5}); - auto y = NDArrayFactory::create_('c', {5, 5}); - auto z = NDArrayFactory::create_('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {5, 5}); - x->assign(3); - y->assign(1); - z->assign(2); - exp.assign(3); - - auto zu = NDArrayFactory::create('c', {5, 5}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - variableSpace->putVariable(-3, z); - variableSpace->putVariable(1, new Variable(NDArrayFactory::create_('c', {5, 5}))); - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1, -2, -3}); - - nd4j::ops::mergemax merge; - - merge.execute(block); - - auto res = variableSpace->getVariable(1)->getNDArray(); - - ASSERT_TRUE(res->equalsTo(&exp)); - - delete block; - delete variableSpace; -} - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, MergeAvgTest1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index d4f422461..8c484268e 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -845,5 +845,44 @@ TEST_F(DeclarableOpsTests13, batch_to_space_nd_3) { delete result; } +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, mergemax_1) { + + NDArray x1('c', {5, 5}, nd4j::DataType::FLOAT32); + NDArray x2('c', {5, 5}, nd4j::DataType::FLOAT32); + NDArray x3('c', {5, 5}, nd4j::DataType::FLOAT32); + NDArray e('c', {5, 5}, nd4j::DataType::FLOAT32); + x1.assign(3); + x2.assign(1); + x3.assign(2); + e.assign(3); + + + nd4j::ops::mergemax op; + auto result = op.execute({&x1, &x2, &x3}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + // z->printBuffer(); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); + + delete result; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, mergemax_2) { + + NDArray x1('c', {1, 3}, {0., 1, 2}, nd4j::DataType::FLOAT32); + NDArray x2('c', {1, 1}, {1.}, nd4j::DataType::FLOAT32); + NDArray out('c', {1, 3}, {-1., -1, -1}, nd4j::DataType::FLOAT32); + + nd4j::ops::mergemax op; + auto status = op.execute({&x1, &x2}, {&out}, {}, {}, {}); + + ASSERT_EQ(20, status); +} + diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp index bc716cc8e..62d297a50 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp @@ -1548,8 +1548,8 @@ TEST_F(DeclarableOpsTests6, LogMatrixDeterminant_1) { ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); - //z->printIndexedBuffer("Output "); - //exp.printIndexedBuffer("Expected "); + z->printIndexedBuffer("Log ABS Output "); + exp.printIndexedBuffer("Log ABS Expected "); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1671,6 +1671,40 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) { delete result; } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, MatrixInverse_010) { + + auto x = NDArrayFactory::create('c', {1, 5, 5}, { + 1., 0., 0., 0., 0., + 2., 1., 0., 0., 0., + 30., 2., 1., 0., 0., + 4., 3., 2., 1., 0., + 5., 4., 3., 2., 1., + }); + + auto exp = NDArrayFactory::create('c', {1, 5, 5}, { + 1.0, 0.0, 0.0, 0.0, 0., + -2.0, 1.0, 0., 0., 0., + -26.0, -2.0, 1, 0, 0., + 54.0, 1.0, -2.0, 1, 0., + -27.0, 0.0, 1.0, -2.0, 1. + }); + + nd4j::ops::matrix_inverse op; + auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); + + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); +// z->printIndexedBuffer("010 Output "); +// exp.printIndexedBuffer("010 Expected "); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_01) { @@ -1824,7 +1858,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_03) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_3) { - auto x = NDArrayFactory::create('c', {5, 5}, { + auto x = NDArrayFactory::create('c', {5, 5}, { 4., 0., 0., 0., 0., 4., 2., 0., 0., 0., 30., 2., 1., 0., 0., @@ -1832,7 +1866,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) { 15., 12., 9., 6., 3., }); - auto exp = NDArrayFactory::create('c', {5, 5}, { + auto exp = NDArrayFactory::create('c', {5, 5}, { 0.25, 0.0, 0.0, 0.0, 0.0, -0.50, 0.5, 0.0, 0.0, 0.0, -6.50, -1.0, 1.0, 0.0, 0.0, @@ -1841,13 +1875,13 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) { }); nd4j::ops::matrix_inverse op; - auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::DOUBLE); + auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); - exp.printIndexedBuffer("Expected "); - z->printIndexedBuffer("Output "); +// exp.printIndexedBuffer("Expected "); +// z->printIndexedBuffer("Output "); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1880,8 +1914,42 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) { ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); - z->printIndexedBuffer("Output "); - exp.printIndexedBuffer("Expected "); +// z->printIndexedBuffer("Output "); +// exp.printIndexedBuffer("Expected "); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, MatrixInverse_04) { + + auto x = NDArrayFactory::create('c', {5, 5}, { + 1., 2., 30., 4., 5., + 0., 1., 2., 3., 4., + 0., 0., 1., 2., 3., + 0., 0., 0., 1., 2., + 0., 0., 0., 0., 1. + }); + + auto exp = NDArrayFactory::create('c', {5, 5}, { + 1.0, -2.0, -26.0, 54.0, -27.0, + 0.0, 1.0, -2.0, 1.0, 0.0, + 0.0, 0.0, 1.0, -2.0, 1.0, + 0.0, 0.0, 0.0, 1.0, -2.0, + 0.0, 0.0, 0.0, 0.0, 1.0 + }); + + nd4j::ops::matrix_inverse op; + auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); + + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); +// z->printIndexedBuffer("Output "); +// exp.printIndexedBuffer("Expected "); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 2c7962c5d..47cfa2584 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -9363,15 +9363,28 @@ public static final int PREALLOC_SIZE = 33554432; // #include // #include - + @Namespace("nd4j") @NoOffset public static class OpArgsHolder extends Pointer { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public OpArgsHolder(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public OpArgsHolder(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public OpArgsHolder position(long position) { + return (OpArgsHolder)super.position(position); + } - + // default constructor + public OpArgsHolder() { super((Pointer)null); allocate(); } + private native void allocate(); + // copy constructor + public OpArgsHolder(@Const @ByRef OpArgsHolder other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef OpArgsHolder other); + + // constructor public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/); public OpArgsHolder(@Const @ByRef NDArrayVector inArrs) { super((Pointer)null); allocate(inArrs); } @@ -9387,6 +9400,13 @@ public static final int PREALLOC_SIZE = 33554432; public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/); + // move constructor + + // assignment operator + public native @ByRef @Name("operator =") OpArgsHolder put(@Const @ByRef OpArgsHolder other); + + // move assignment operator + public native @Const @ByRef NDArrayVector getInArrs(); public native @StdVector DoublePointer getTArgs(); @@ -9406,8 +9426,8 @@ public static final int PREALLOC_SIZE = 33554432; public native int getNumBArgs(); public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs, @Cast("const bool") boolean isInPlace/*=false*/); - public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs); - + public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs); + } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 43a05a33a..8e71816f8 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -9060,15 +9060,28 @@ public static final int PREALLOC_SIZE = 33554432; // #include // #include - + @Namespace("nd4j") @NoOffset public static class OpArgsHolder extends Pointer { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public OpArgsHolder(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public OpArgsHolder(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public OpArgsHolder position(long position) { + return (OpArgsHolder)super.position(position); + } - + // default constructor + public OpArgsHolder() { super((Pointer)null); allocate(); } + private native void allocate(); + // copy constructor + public OpArgsHolder(@Const @ByRef OpArgsHolder other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef OpArgsHolder other); + + // constructor public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/); public OpArgsHolder(@Const @ByRef NDArrayVector inArrs) { super((Pointer)null); allocate(inArrs); } @@ -9084,6 +9097,13 @@ public static final int PREALLOC_SIZE = 33554432; public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/); + // move constructor + + // assignment operator + public native @ByRef @Name("operator =") OpArgsHolder put(@Const @ByRef OpArgsHolder other); + + // move assignment operator + public native @Const @ByRef NDArrayVector getInArrs(); public native @StdVector DoublePointer getTArgs(); @@ -9103,8 +9123,8 @@ public static final int PREALLOC_SIZE = 33554432; public native int getNumBArgs(); public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs, @Cast("const bool") boolean isInPlace/*=false*/); - public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs); - + public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs); + } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index d7f5746cc..c2f5dedc5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -45,6 +45,7 @@ import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.nativeblas.NativeOpsHolder; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import static org.junit.Assert.*; @@ -548,6 +549,8 @@ public class CustomOpsTests extends BaseNd4jTest { Nd4j.exec(op); //Execution is OK } + + @Test public void testDepthwise(){ INDArray input = Nd4j.create(DataType.DOUBLE, 1,3,8,8); @@ -625,6 +628,49 @@ public class CustomOpsTests extends BaseNd4jTest { System.out.println(out); } + @Test + public void testUpsampling2dBackprop(){ + + Nd4j.getRandom().setSeed(12345); + int c = 2; + int[] sz = {2,2}; + long[] inSize = {1, c, 3, 3}; + INDArray eps = Nd4j.rand(DataType.FLOAT, 1, c, sz[0] * inSize[2], sz[1] * inSize[3]); + + INDArray input = Nd4j.create(inSize); //Unused, not sure why this is even an arg... + INDArray exp = Nd4j.create(DataType.FLOAT, inSize); + + for( int ch=0; ch