[WIP] bitwise ops (#115)
* - cyclic_shift_bits + test - shift_bits + test Signed-off-by: raver119 <raver119@gmail.com> * OMP_IF replacement Signed-off-by: raver119 <raver119@gmail.com>master
parent
59cba587f4
commit
6264530dd8
|
@ -2462,7 +2462,7 @@ double NDArray::getTrace() const {
|
||||||
|
|
||||||
double sum = 0.;
|
double sum = 0.;
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(reduction(OMP_SUMT:sum) if(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(reduction(OMP_SUMT:sum) OMP_IF(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||||
for(int i = 0; i < minDim; ++i)
|
for(int i = 0; i < minDim; ++i)
|
||||||
sum += e<double>(i * offset);
|
sum += e<double>(i * offset);
|
||||||
|
|
||||||
|
|
|
@ -100,7 +100,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, const char
|
||||||
|
|
||||||
std::vector<Nd4jLong> coords(zRank);
|
std::vector<Nd4jLong> coords(zRank);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords))
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords))
|
||||||
for (Nd4jLong i = 0; i < zLen; ++i) {
|
for (Nd4jLong i = 0; i < zLen; ++i) {
|
||||||
|
|
||||||
shape::index2coords(zRank, target->shapeOf(), i, zLen, coords.data());
|
shape::index2coords(zRank, target->shapeOf(), i, zLen, coords.data());
|
||||||
|
@ -141,7 +141,7 @@ void NDArray::setIdentity() {
|
||||||
minDim = shape[i];
|
minDim = shape[i];
|
||||||
|
|
||||||
float v = 1.0f;
|
float v = 1.0f;
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||||
for(int i = 0; i < minDim; ++i)
|
for(int i = 0; i < minDim; ++i)
|
||||||
templatedSet<float>(buffer(), i*offset, this->dataType(), &v);
|
templatedSet<float>(buffer(), i*offset, this->dataType(), &v);
|
||||||
}
|
}
|
||||||
|
|
|
@ -922,7 +922,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
||||||
//*********************************************//
|
//*********************************************//
|
||||||
case LoopKind::EWS1: {
|
case LoopKind::EWS1: {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||||
for (uint i = 0; i < zLen; ++i) {
|
for (uint i = 0; i < zLen; ++i) {
|
||||||
|
|
||||||
extraParams[0] = param0;
|
extraParams[0] = param0;
|
||||||
|
@ -944,7 +944,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
||||||
//*********************************************//
|
//*********************************************//
|
||||||
case LoopKind::EWSNONZERO: {
|
case LoopKind::EWSNONZERO: {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||||
for (uint i = 0; i < zLen; ++i) {
|
for (uint i = 0; i < zLen; ++i) {
|
||||||
|
|
||||||
extraParams[0] = param0;
|
extraParams[0] = param0;
|
||||||
|
@ -966,7 +966,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
||||||
//*********************************************//
|
//*********************************************//
|
||||||
case LoopKind::RANK1: {
|
case LoopKind::RANK1: {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||||
for (uint i = 0; i < zLen; i++) {
|
for (uint i = 0; i < zLen; i++) {
|
||||||
|
|
||||||
extraParams[0] = param0;
|
extraParams[0] = param0;
|
||||||
|
@ -990,7 +990,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
||||||
//*********************************************//
|
//*********************************************//
|
||||||
case LoopKind::RANK2: {
|
case LoopKind::RANK2: {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||||
for (uint i = 0; i < zLen; i++) {
|
for (uint i = 0; i < zLen; i++) {
|
||||||
|
|
||||||
extraParams[0] = param0;
|
extraParams[0] = param0;
|
||||||
|
@ -1016,7 +1016,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
||||||
//*********************************************//
|
//*********************************************//
|
||||||
case LoopKind::RANK3: {
|
case LoopKind::RANK3: {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||||
for (uint i = 0; i < zLen; i++) {
|
for (uint i = 0; i < zLen; i++) {
|
||||||
|
|
||||||
extraParams[0] = param0;
|
extraParams[0] = param0;
|
||||||
|
@ -1044,7 +1044,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
||||||
//*********************************************//
|
//*********************************************//
|
||||||
case LoopKind::RANK4: {
|
case LoopKind::RANK4: {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||||
for (uint i = 0; i < zLen; i++) {
|
for (uint i = 0; i < zLen; i++) {
|
||||||
|
|
||||||
extraParams[0] = param0;
|
extraParams[0] = param0;
|
||||||
|
@ -1074,7 +1074,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
||||||
//*********************************************//
|
//*********************************************//
|
||||||
case LoopKind::RANK5: {
|
case LoopKind::RANK5: {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||||
for (uint i = 0; i < zLen; i++) {
|
for (uint i = 0; i < zLen; i++) {
|
||||||
|
|
||||||
extraParams[0] = param0;
|
extraParams[0] = param0;
|
||||||
|
@ -1111,7 +1111,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
||||||
|
|
||||||
if(shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) {
|
if(shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||||
for (uint i = 0; i < zLen; ++i) {
|
for (uint i = 0; i < zLen; ++i) {
|
||||||
|
|
||||||
extraParams[0] = param0;
|
extraParams[0] = param0;
|
||||||
|
@ -1135,7 +1135,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
||||||
uint castYTadShapeInfo[MAX_RANK];
|
uint castYTadShapeInfo[MAX_RANK];
|
||||||
const bool canCastYTad = nd4j::DataTypeUtils::castShapeInfo<uint>(yTadShapeInfo, castYTadShapeInfo);
|
const bool canCastYTad = nd4j::DataTypeUtils::castShapeInfo<uint>(yTadShapeInfo, castYTadShapeInfo);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||||
for (uint i = 0; i < zLen; ++i) {
|
for (uint i = 0; i < zLen; ++i) {
|
||||||
|
|
||||||
extraParams[0] = param0;
|
extraParams[0] = param0;
|
||||||
|
@ -1199,7 +1199,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
||||||
//*********************************************//
|
//*********************************************//
|
||||||
case LoopKind::EWS1: {
|
case LoopKind::EWS1: {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||||
for (uint ix = 0; ix < numXTads; ++ix) {
|
for (uint ix = 0; ix < numXTads; ++ix) {
|
||||||
for (uint iy = 0; iy < numYTads; ++iy) {
|
for (uint iy = 0; iy < numYTads; ++iy) {
|
||||||
|
|
||||||
|
@ -1224,7 +1224,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
||||||
//*********************************************//
|
//*********************************************//
|
||||||
case LoopKind::EWSNONZERO: {
|
case LoopKind::EWSNONZERO: {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||||
for (uint ix = 0; ix < numXTads; ++ix) {
|
for (uint ix = 0; ix < numXTads; ++ix) {
|
||||||
for (uint iy = 0; iy < numYTads; ++iy) {
|
for (uint iy = 0; iy < numYTads; ++iy) {
|
||||||
|
|
||||||
|
@ -1249,7 +1249,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
||||||
//*********************************************//
|
//*********************************************//
|
||||||
case LoopKind::RANK1: {
|
case LoopKind::RANK1: {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||||
for (uint ix = 0; ix < numXTads; ++ix) {
|
for (uint ix = 0; ix < numXTads; ++ix) {
|
||||||
for (uint iy = 0; iy < numYTads; ++iy) {
|
for (uint iy = 0; iy < numYTads; ++iy) {
|
||||||
|
|
||||||
|
@ -1276,7 +1276,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
||||||
//*********************************************//
|
//*********************************************//
|
||||||
case LoopKind::RANK2: {
|
case LoopKind::RANK2: {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||||
for (uint ix = 0; ix < numXTads; ++ix) {
|
for (uint ix = 0; ix < numXTads; ++ix) {
|
||||||
for (uint iy = 0; iy < numYTads; ++iy) {
|
for (uint iy = 0; iy < numYTads; ++iy) {
|
||||||
|
|
||||||
|
@ -1305,7 +1305,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
||||||
//*********************************************//
|
//*********************************************//
|
||||||
case LoopKind::RANK3: {
|
case LoopKind::RANK3: {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||||
for (uint ix = 0; ix < numXTads; ++ix) {
|
for (uint ix = 0; ix < numXTads; ++ix) {
|
||||||
for (uint iy = 0; iy < numYTads; ++iy) {
|
for (uint iy = 0; iy < numYTads; ++iy) {
|
||||||
|
|
||||||
|
@ -1336,7 +1336,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
||||||
//*********************************************//
|
//*********************************************//
|
||||||
case LoopKind::RANK4: {
|
case LoopKind::RANK4: {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||||
for (uint ix = 0; ix < numXTads; ++ix) {
|
for (uint ix = 0; ix < numXTads; ++ix) {
|
||||||
for (uint iy = 0; iy < numYTads; ++iy) {
|
for (uint iy = 0; iy < numYTads; ++iy) {
|
||||||
|
|
||||||
|
@ -1369,7 +1369,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
||||||
//*********************************************//
|
//*********************************************//
|
||||||
case LoopKind::RANK5: {
|
case LoopKind::RANK5: {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||||
for (uint ix = 0; ix < numXTads; ++ix) {
|
for (uint ix = 0; ix < numXTads; ++ix) {
|
||||||
for (uint iy = 0; iy < numYTads; ++iy) {
|
for (uint iy = 0; iy < numYTads; ++iy) {
|
||||||
|
|
||||||
|
@ -1409,7 +1409,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
||||||
|
|
||||||
if(shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) {
|
if(shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||||
for (uint ix = 0; ix < numXTads; ++ix) {
|
for (uint ix = 0; ix < numXTads; ++ix) {
|
||||||
for (uint iy = 0; iy < numYTads; ++iy) {
|
for (uint iy = 0; iy < numYTads; ++iy) {
|
||||||
|
|
||||||
|
@ -1435,7 +1435,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
||||||
uint castYTadShapeInfo[MAX_RANK];
|
uint castYTadShapeInfo[MAX_RANK];
|
||||||
const bool canCastYTad = nd4j::DataTypeUtils::castShapeInfo<uint>(yTadShapeInfo, castYTadShapeInfo);
|
const bool canCastYTad = nd4j::DataTypeUtils::castShapeInfo<uint>(yTadShapeInfo, castYTadShapeInfo);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||||
for (uint ix = 0; ix < numXTads; ++ix) {
|
for (uint ix = 0; ix < numXTads; ++ix) {
|
||||||
for (uint iy = 0; iy < numYTads; ++iy) {
|
for (uint iy = 0; iy < numYTads; ++iy) {
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c
|
||||||
const bool flagA = (flagC && transA) || (!flagC && !transA);
|
const bool flagA = (flagC && transA) || (!flagC && !transA);
|
||||||
const bool flagB = (flagC && transB) || (!flagC && !transB);
|
const bool flagB = (flagC && transB) || (!flagC && !transB);
|
||||||
|
|
||||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||||
// for(uint row = 0; row < M; ++row) {
|
// for(uint row = 0; row < M; ++row) {
|
||||||
|
|
||||||
// T3* c = flagC ? (C + row) : (C + row * ldc);
|
// T3* c = flagC ? (C + row) : (C + row * ldc);
|
||||||
|
@ -74,7 +74,7 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided) collapse(2))
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided) collapse(2))
|
||||||
for(uint row = 0; row < M; ++row) {
|
for(uint row = 0; row < M; ++row) {
|
||||||
for(uint col = 0; col < N; ++col) {
|
for(uint col = 0; col < N; ++col) {
|
||||||
|
|
||||||
|
@ -108,7 +108,7 @@ static void usualGemv(const char aOrder, const int M, const int N, const double
|
||||||
|
|
||||||
const bool flagA = aOrder == 'f';
|
const bool flagA = aOrder == 'f';
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(M > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||||
for(int row = 0; row < M; ++row) {
|
for(int row = 0; row < M; ++row) {
|
||||||
|
|
||||||
T3* y = Y + row * incy;
|
T3* y = Y + row * incy;
|
||||||
|
@ -139,7 +139,7 @@ static void usualDot(const Nd4jLong length, const double alpha, const void* vX,
|
||||||
T3 alphaZ(alpha), betaZ(beta);
|
T3 alphaZ(alpha), betaZ(beta);
|
||||||
|
|
||||||
T3 sum = 0;
|
T3 sum = 0;
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(length > Environment::getInstance()->elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum))
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(length > Environment::getInstance()->elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum))
|
||||||
for(int i = 0; i < length; ++i)
|
for(int i = 0; i < length; ++i)
|
||||||
sum = sum + X[i * incx] * Y[i * incy];
|
sum = sum + X[i * incx] * Y[i * incy];
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_cyclic_shift_bits)
|
||||||
|
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <ops/declarable/helpers/helpers.h>
|
||||||
|
#include <ops/declarable/helpers/shift.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
CONFIGURABLE_OP_IMPL(cyclic_shift_bits, 1, 1, true, 0, -2) {
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "cyclic_shift_bits: actual shift value is missing");
|
||||||
|
|
||||||
|
uint32_t shift = 0;
|
||||||
|
if (block.width() > 1) {
|
||||||
|
shift = INPUT_VARIABLE(1)->e<uint32_t>(0);
|
||||||
|
} else if (block.numI() > 0) {
|
||||||
|
shift = INT_ARG(0);
|
||||||
|
};
|
||||||
|
|
||||||
|
helpers::cyclic_shift_bits(block.launchContext(), *input, *output, shift);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "cyclic_shift_bits: can't shift beyond size of data type")
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(cyclic_shift_bits) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes({ALL_INTS})
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,58 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_shift_bits)
|
||||||
|
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <ops/declarable/helpers/helpers.h>
|
||||||
|
#include <ops/declarable/helpers/shift.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
CONFIGURABLE_OP_IMPL(shift_bits, 1, 1, true, 0, -2) {
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "shift_bits: actual shift value is missing");
|
||||||
|
|
||||||
|
uint32_t shift = 0;
|
||||||
|
if (block.width() > 1) {
|
||||||
|
shift = INPUT_VARIABLE(1)->e<uint32_t>(0);
|
||||||
|
} else if (block.numI() > 0) {
|
||||||
|
shift = INT_ARG(0);
|
||||||
|
};
|
||||||
|
|
||||||
|
REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "cyclic_shift_bits: can't shift beyond size of data type")
|
||||||
|
|
||||||
|
helpers::shift_bits(block.launchContext(), *input, *output, shift);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(shift_bits) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes({ALL_INTS})
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -28,13 +28,36 @@ namespace nd4j {
|
||||||
/**
|
/**
|
||||||
* This operation toggles individual bits of each element in array
|
* This operation toggles individual bits of each element in array
|
||||||
*
|
*
|
||||||
* PLEASE NOTE: This operation is possible only on integer datatypes
|
* PLEASE NOTE: This operation is possible only on integer data types
|
||||||
*
|
*
|
||||||
* @tparam T
|
* @tparam T
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_toggle_bits)
|
#if NOT_EXCLUDED(OP_toggle_bits)
|
||||||
DECLARE_OP(toggle_bits, -1, -1, true);
|
DECLARE_OP(toggle_bits, -1, -1, true);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This operation shift individual bits of each element in array
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||||
|
*
|
||||||
|
* @tparam T
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_shift_bits)
|
||||||
|
DECLARE_CONFIGURABLE_OP(shift_bits, 1, 1, true, 0, -2);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This operation shift individual bits of each element in array
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||||
|
*
|
||||||
|
* @tparam T
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_cyclic_shift_bits)
|
||||||
|
DECLARE_CONFIGURABLE_OP(cyclic_shift_bits, 1, 1, true, 0, -2);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -35,8 +35,8 @@ void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& ind
|
||||||
|
|
||||||
if(outRank == 1) {
|
if(outRank == 1) {
|
||||||
|
|
||||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided))
|
||||||
for(Nd4jLong i = 0; i < indLen; ++i) {
|
for(Nd4jLong i = 0; i < indLen; ++i) {
|
||||||
|
|
||||||
Nd4jLong idx = indices.e<Nd4jLong>(i);
|
Nd4jLong idx = indices.e<Nd4jLong>(i);
|
||||||
|
@ -54,8 +54,8 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
|
||||||
std::vector<int> dimsToExcludeUpd(sizeOfDims);
|
std::vector<int> dimsToExcludeUpd(sizeOfDims);
|
||||||
std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0);
|
std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0);
|
||||||
|
|
||||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // causes known openMP asan bug !
|
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // causes known openMP asan bug !
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided))
|
||||||
for(Nd4jLong i = 0; i < indLen; ++i) {
|
for(Nd4jLong i = 0; i < indLen; ++i) {
|
||||||
|
|
||||||
NDArray outSubArr = output(indices.e<Nd4jLong>(i), std::vector<int>({0}));
|
NDArray outSubArr = output(indices.e<Nd4jLong>(i), std::vector<int>({0}));
|
||||||
|
@ -76,8 +76,8 @@ void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& i
|
||||||
|
|
||||||
if(outRank == 1) {
|
if(outRank == 1) {
|
||||||
|
|
||||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided))
|
||||||
for(Nd4jLong i = 0; i < indLen; ++i) {
|
for(Nd4jLong i = 0; i < indLen; ++i) {
|
||||||
|
|
||||||
Nd4jLong idx = indices.e<Nd4jLong>(i);
|
Nd4jLong idx = indices.e<Nd4jLong>(i);
|
||||||
|
@ -93,8 +93,8 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
|
||||||
std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0);
|
std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0);
|
||||||
std::vector<Nd4jLong> idxRangeOut(2*outRank, 0);
|
std::vector<Nd4jLong> idxRangeOut(2*outRank, 0);
|
||||||
|
|
||||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen/indLastDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided) firstprivate(idxRangeOut))
|
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen/indLastDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided) firstprivate(idxRangeOut))
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided) firstprivate(idxRangeOut))
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided) firstprivate(idxRangeOut))
|
||||||
for(Nd4jLong i = 0; i < indLen/indLastDim; ++i) {
|
for(Nd4jLong i = 0; i < indLen/indLastDim; ++i) {
|
||||||
|
|
||||||
NDArray indSubArr = indices(i, dimsToExcludeInd);
|
NDArray indSubArr = indices(i, dimsToExcludeInd);
|
||||||
|
|
|
@ -479,7 +479,7 @@ namespace helpers {
|
||||||
for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) {
|
for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) {
|
||||||
auto outputT = listOfOutTensors->at(fi->first);
|
auto outputT = listOfOutTensors->at(fi->first);
|
||||||
outputT->assign(listOfTensors->at(fi->second.at(0)));
|
outputT->assign(listOfTensors->at(fi->second.at(0)));
|
||||||
auto loopSize = fi->second.size();
|
Nd4jLong loopSize = fi->second.size();
|
||||||
PRAGMA_OMP_PARALLEL_FOR
|
PRAGMA_OMP_PARALLEL_FOR
|
||||||
for (Nd4jLong idx = 1; idx < loopSize; ++idx) {
|
for (Nd4jLong idx = 1; idx < loopSize; ++idx) {
|
||||||
auto current = listOfTensors->at(fi->second.at(idx));
|
auto current = listOfTensors->at(fi->second.at(idx));
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/shift.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
template <typename T>
|
||||||
|
void shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) {
|
||||||
|
auto lambda = LAMBDA_T(x, shift) {
|
||||||
|
return x << shift;
|
||||||
|
};
|
||||||
|
|
||||||
|
input.applyLambda<T>(lambda, &output);
|
||||||
|
}
|
||||||
|
|
||||||
|
void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) {
|
||||||
|
BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void cyclic_shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) {
|
||||||
|
auto step = (sizeof(T) * 8) - shift;
|
||||||
|
auto lambda = LAMBDA_T(x, shift, step) {
|
||||||
|
return x << shift | x >> step;
|
||||||
|
};
|
||||||
|
|
||||||
|
input.applyLambda<T>(lambda, &output);
|
||||||
|
}
|
||||||
|
|
||||||
|
void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) {
|
||||||
|
BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -562,7 +562,7 @@ static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) {
|
||||||
|
|
||||||
std::vector<Nd4jLong> coords(maxRank);
|
std::vector<Nd4jLong> coords(maxRank);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords))
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords))
|
||||||
for (Nd4jLong i = 0; i < zLen; ++i) {
|
for (Nd4jLong i = 0; i < zLen; ++i) {
|
||||||
|
|
||||||
Nd4jLong *zCoordStart, *xCoordStart;
|
Nd4jLong *zCoordStart, *xCoordStart;
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/shift.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
template <typename T>
|
||||||
|
void shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) {
|
||||||
|
auto lambda = LAMBDA_T(x, shift) {
|
||||||
|
return x << shift;
|
||||||
|
};
|
||||||
|
|
||||||
|
input.applyLambda(lambda, &output);
|
||||||
|
}
|
||||||
|
|
||||||
|
void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) {
|
||||||
|
BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void cyclic_shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) {
|
||||||
|
auto step = (sizeof(T) * 8) - shift;
|
||||||
|
auto lambda = LAMBDA_T(x, shift, step) {
|
||||||
|
return x << shift | x >> step;
|
||||||
|
};
|
||||||
|
|
||||||
|
input.applyLambda(lambda, &output);
|
||||||
|
}
|
||||||
|
|
||||||
|
void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) {
|
||||||
|
BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,38 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef DEV_TESTS_SHIFT_H
|
||||||
|
#define DEV_TESTS_SHIFT_H
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#include <types/types.h>
|
||||||
|
#include <NDArray.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift);
|
||||||
|
|
||||||
|
void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif //DEV_TESTS_SHIFT_H
|
|
@ -1,5 +1,5 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -33,8 +33,8 @@ class DeclarableOpsTests13 : public testing::Test {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
DeclarableOpsTests13() {
|
DeclarableOpsTests13() {
|
||||||
printf("\n");
|
//printf("\n");
|
||||||
fflush(stdout);
|
//fflush(stdout);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -103,8 +103,9 @@ TEST_F(DeclarableOpsTests13, test_argmax_edge_1) {
|
||||||
|
|
||||||
nd4j::ops::argmax op;
|
nd4j::ops::argmax op;
|
||||||
auto result = op.execute(ctx);
|
auto result = op.execute(ctx);
|
||||||
|
ASSERT_EQ(Status::OK(), result);
|
||||||
|
|
||||||
nd4j_printf("Done\n","");
|
//nd4j_printf("Done\n","");
|
||||||
delete ctx;
|
delete ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -258,7 +259,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_1) {
|
||||||
|
|
||||||
|
|
||||||
ASSERT_EQ(result->status(), Status::OK());
|
ASSERT_EQ(result->status(), Status::OK());
|
||||||
result->at(0)->printBuffer("Output");
|
//result->at(0)->printBuffer("Output");
|
||||||
ASSERT_TRUE(exp1.equalsTo(result->at(0)));
|
ASSERT_TRUE(exp1.equalsTo(result->at(0)));
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
@ -306,8 +307,8 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_3) {
|
||||||
|
|
||||||
//nd4j_printf("rows %lld, cols %lld, vals %lld, res full %lld\n", rows.lengthOf(), cols.lengthOf(), vals.lengthOf(), exp1.lengthOf());
|
//nd4j_printf("rows %lld, cols %lld, vals %lld, res full %lld\n", rows.lengthOf(), cols.lengthOf(), vals.lengthOf(), exp1.lengthOf());
|
||||||
ASSERT_EQ(result->status(), Status::OK());
|
ASSERT_EQ(result->status(), Status::OK());
|
||||||
result->at(0)->printBuffer("Output");
|
//result->at(0)->printBuffer("Output");
|
||||||
exp.printBuffer("Expect");
|
//exp.printBuffer("Expect");
|
||||||
//result->at(0)->printShapeInfo("Shape output");
|
//result->at(0)->printShapeInfo("Shape output");
|
||||||
ASSERT_TRUE(exp.equalsTo(result->at(0)));
|
ASSERT_TRUE(exp.equalsTo(result->at(0)));
|
||||||
delete result;
|
delete result;
|
||||||
|
@ -327,7 +328,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_1) {
|
||||||
nd4j::ops::barnes_symmetrized op;
|
nd4j::ops::barnes_symmetrized op;
|
||||||
auto result = op.execute({&rows, &cols, &vals}, {}, {1});
|
auto result = op.execute({&rows, &cols, &vals}, {}, {1});
|
||||||
ASSERT_EQ(result->status(), Status::OK());
|
ASSERT_EQ(result->status(), Status::OK());
|
||||||
result->at(2)->printBuffer("Symmetrized1");
|
//result->at(2)->printBuffer("Symmetrized1");
|
||||||
ASSERT_TRUE(exp.equalsTo(result->at(2)));
|
ASSERT_TRUE(exp.equalsTo(result->at(2)));
|
||||||
|
|
||||||
delete result;
|
delete result;
|
||||||
|
@ -346,7 +347,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_2) {
|
||||||
nd4j::ops::barnes_symmetrized op;
|
nd4j::ops::barnes_symmetrized op;
|
||||||
auto result = op.execute({&rows, &cols, &vals}, {}, {3});
|
auto result = op.execute({&rows, &cols, &vals}, {}, {3});
|
||||||
ASSERT_EQ(result->status(), Status::OK());
|
ASSERT_EQ(result->status(), Status::OK());
|
||||||
result->at(2)->printBuffer("Symmetrized2");
|
//result->at(2)->printBuffer("Symmetrized2");
|
||||||
// ASSERT_TRUE(exp[i]->equalsTo(result->at(i)));
|
// ASSERT_TRUE(exp[i]->equalsTo(result->at(i)));
|
||||||
ASSERT_TRUE(exp.equalsTo(result->at(2)));
|
ASSERT_TRUE(exp.equalsTo(result->at(2)));
|
||||||
delete result;
|
delete result;
|
||||||
|
@ -365,7 +366,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_3) {
|
||||||
nd4j::ops::barnes_symmetrized op;
|
nd4j::ops::barnes_symmetrized op;
|
||||||
auto result = op.execute({&rows, &cols, &vals}, {}, {11});
|
auto result = op.execute({&rows, &cols, &vals}, {}, {11});
|
||||||
ASSERT_EQ(result->status(), Status::OK());
|
ASSERT_EQ(result->status(), Status::OK());
|
||||||
result->at(2)->printBuffer("Symmetrized3");
|
//result->at(2)->printBuffer("Symmetrized3");
|
||||||
//exp.printBuffer("EXPect symm3");
|
//exp.printBuffer("EXPect symm3");
|
||||||
// ASSERT_TRUE(exp[i]->equalsTo(result->at(i)));
|
// ASSERT_TRUE(exp[i]->equalsTo(result->at(i)));
|
||||||
//ASSERT_TRUE(exp.equalsTo(result->at(0)));
|
//ASSERT_TRUE(exp.equalsTo(result->at(0)));
|
||||||
|
@ -390,10 +391,10 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_4) {
|
||||||
auto result = op.execute({&rows, &cols, &vals}, {}, {11});
|
auto result = op.execute({&rows, &cols, &vals}, {}, {11});
|
||||||
ASSERT_EQ(result->status(), Status::OK());
|
ASSERT_EQ(result->status(), Status::OK());
|
||||||
auto res = result->at(2);
|
auto res = result->at(2);
|
||||||
res->printBuffer("Symmetrized4");
|
// res->printBuffer("Symmetrized4");
|
||||||
exp4.printBuffer("Expected sym");
|
// exp4.printBuffer("Expected sym");
|
||||||
nd4j_printf("Total res is {1, %lld}\n", res->lengthOf());
|
// nd4j_printf("Total res is {1, %lld}\n", res->lengthOf());
|
||||||
nd4j_printf("Expected is {1, %lld}\n", exp4.lengthOf());
|
// nd4j_printf("Expected is {1, %lld}\n", exp4.lengthOf());
|
||||||
|
|
||||||
//exp.printBuffer("EXPect symm3");
|
//exp.printBuffer("EXPect symm3");
|
||||||
// ASSERT_TRUE(exp[i]->equalsTo(result->at(i)));
|
// ASSERT_TRUE(exp[i]->equalsTo(result->at(i)));
|
||||||
|
@ -619,3 +620,38 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_5) {
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests13, shift_bits_1) {
|
||||||
|
auto x = NDArrayFactory::create<int>('c', {5});
|
||||||
|
auto e = x.ulike();
|
||||||
|
x.assign(32);
|
||||||
|
e.assign(512);
|
||||||
|
|
||||||
|
nd4j::ops::shift_bits op;
|
||||||
|
auto result = op.execute({&x}, {}, {4});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) {
|
||||||
|
auto x = NDArrayFactory::create<int>('c', {5});
|
||||||
|
auto e = x.ulike();
|
||||||
|
x.assign(32);
|
||||||
|
e.assign(512);
|
||||||
|
|
||||||
|
nd4j::ops::cyclic_shift_bits op;
|
||||||
|
auto result = op.execute({&x}, {}, {4});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue