[WIP] bitwise ops (#115)

* - cyclic_shift_bits + test
- shift_bits + test

Signed-off-by: raver119 <raver119@gmail.com>

* OMP_IF replacement

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-15 11:49:50 +03:00 committed by GitHub
parent 59cba587f4
commit 6264530dd8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 371 additions and 50 deletions

View File

@ -2462,7 +2462,7 @@ double NDArray::getTrace() const {
double sum = 0.;
PRAGMA_OMP_PARALLEL_FOR_ARGS(reduction(OMP_SUMT:sum) if(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
PRAGMA_OMP_PARALLEL_FOR_ARGS(reduction(OMP_SUMT:sum) OMP_IF(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
for(int i = 0; i < minDim; ++i)
sum += e<double>(i * offset);

View File

@ -100,7 +100,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, const char
std::vector<Nd4jLong> coords(zRank);
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords))
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords))
for (Nd4jLong i = 0; i < zLen; ++i) {
shape::index2coords(zRank, target->shapeOf(), i, zLen, coords.data());
@ -141,7 +141,7 @@ void NDArray::setIdentity() {
minDim = shape[i];
float v = 1.0f;
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
for(int i = 0; i < minDim; ++i)
templatedSet<float>(buffer(), i*offset, this->dataType(), &v);
}

View File

@ -922,7 +922,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::EWS1: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint i = 0; i < zLen; ++i) {
extraParams[0] = param0;
@ -944,7 +944,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::EWSNONZERO: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint i = 0; i < zLen; ++i) {
extraParams[0] = param0;
@ -966,7 +966,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::RANK1: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint i = 0; i < zLen; i++) {
extraParams[0] = param0;
@ -990,7 +990,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::RANK2: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint i = 0; i < zLen; i++) {
extraParams[0] = param0;
@ -1016,7 +1016,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::RANK3: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint i = 0; i < zLen; i++) {
extraParams[0] = param0;
@ -1044,7 +1044,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::RANK4: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint i = 0; i < zLen; i++) {
extraParams[0] = param0;
@ -1074,7 +1074,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::RANK5: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint i = 0; i < zLen; i++) {
extraParams[0] = param0;
@ -1111,7 +1111,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
if(shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint i = 0; i < zLen; ++i) {
extraParams[0] = param0;
@ -1135,7 +1135,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
uint castYTadShapeInfo[MAX_RANK];
const bool canCastYTad = nd4j::DataTypeUtils::castShapeInfo<uint>(yTadShapeInfo, castYTadShapeInfo);
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint i = 0; i < zLen; ++i) {
extraParams[0] = param0;
@ -1199,7 +1199,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::EWS1: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint ix = 0; ix < numXTads; ++ix) {
for (uint iy = 0; iy < numYTads; ++iy) {
@ -1224,7 +1224,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::EWSNONZERO: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint ix = 0; ix < numXTads; ++ix) {
for (uint iy = 0; iy < numYTads; ++iy) {
@ -1249,7 +1249,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::RANK1: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint ix = 0; ix < numXTads; ++ix) {
for (uint iy = 0; iy < numYTads; ++iy) {
@ -1276,7 +1276,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::RANK2: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint ix = 0; ix < numXTads; ++ix) {
for (uint iy = 0; iy < numYTads; ++iy) {
@ -1305,7 +1305,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::RANK3: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint ix = 0; ix < numXTads; ++ix) {
for (uint iy = 0; iy < numYTads; ++iy) {
@ -1336,7 +1336,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::RANK4: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint ix = 0; ix < numXTads; ++ix) {
for (uint iy = 0; iy < numYTads; ++iy) {
@ -1369,7 +1369,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::RANK5: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint ix = 0; ix < numXTads; ++ix) {
for (uint iy = 0; iy < numYTads; ++iy) {
@ -1409,7 +1409,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
if(shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint ix = 0; ix < numXTads; ++ix) {
for (uint iy = 0; iy < numYTads; ++iy) {
@ -1435,7 +1435,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
uint castYTadShapeInfo[MAX_RANK];
const bool canCastYTad = nd4j::DataTypeUtils::castShapeInfo<uint>(yTadShapeInfo, castYTadShapeInfo);
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint ix = 0; ix < numXTads; ++ix) {
for (uint iy = 0; iy < numYTads; ++iy) {

View File

@ -40,7 +40,7 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c
const bool flagA = (flagC && transA) || (!flagC && !transA);
const bool flagB = (flagC && transB) || (!flagC && !transB);
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
// for(uint row = 0; row < M; ++row) {
// T3* c = flagC ? (C + row) : (C + row * ldc);
@ -74,7 +74,7 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c
// }
// }
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided) collapse(2))
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided) collapse(2))
for(uint row = 0; row < M; ++row) {
for(uint col = 0; col < N; ++col) {
@ -108,7 +108,7 @@ static void usualGemv(const char aOrder, const int M, const int N, const double
const bool flagA = aOrder == 'f';
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(M > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
for(int row = 0; row < M; ++row) {
T3* y = Y + row * incy;
@ -139,7 +139,7 @@ static void usualDot(const Nd4jLong length, const double alpha, const void* vX,
T3 alphaZ(alpha), betaZ(beta);
T3 sum = 0;
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(length > Environment::getInstance()->elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum))
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(length > Environment::getInstance()->elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum))
for(int i = 0; i < length; ++i)
sum = sum + X[i * incx] * Y[i * incy];

View File

@ -0,0 +1,58 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_cyclic_shift_bits)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/shift.h>
namespace nd4j {
namespace ops {
/**
 * cyclic_shift_bits: cyclically shifts the bits of every element of the input
 * array to the left and stores the result in the output array.
 *
 * The shift amount is read from the first element of the second input when one
 * is supplied, otherwise from the first integer argument.
 *
 * Validation fails when no shift amount is provided, or when the shift is not
 * strictly between 0 and the bit width of the input element type.
 */
CONFIGURABLE_OP_IMPL(cyclic_shift_bits, 1, 1, true, 0, -2) {
    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "cyclic_shift_bits: actual shift value is missing");

    uint32_t shift = 0;
    if (block.width() > 1) {
        // a second input array takes precedence over the integer argument
        shift = INPUT_VARIABLE(1)->e<uint32_t>(0);
    } else if (block.numI() > 0) {
        shift = INT_ARG(0);
    }

    // The range check must run BEFORE the helper: an out-of-range shift would
    // otherwise already have been applied to the data by the time it is rejected.
    REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "cyclic_shift_bits: can't shift beyond size of data type");

    helpers::cyclic_shift_bits(block.launchContext(), *input, *output, shift);

    return Status::OK();
}
// Type constraints for cyclic_shift_bits: allowed input types are limited to
// ALL_INTS and "same mode" is switched on for the op descriptor.
// NOTE(review): same mode presumably makes the output type follow the input type - confirm.
DECLARE_TYPES(cyclic_shift_bits) {
getOpDescriptor()
->setAllowedInputTypes({ALL_INTS})
->setSameMode(true);
}
}
}
#endif

View File

@ -0,0 +1,58 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_shift_bits)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/shift.h>
namespace nd4j {
namespace ops {
/**
 * shift_bits: shifts the bits of every element of the input array to the left
 * and stores the result in the output array.
 *
 * The shift amount is read from the first element of the second input when one
 * is supplied, otherwise from the first integer argument.
 *
 * Validation fails when no shift amount is provided, or when the shift is not
 * strictly between 0 and the bit width of the input element type.
 */
CONFIGURABLE_OP_IMPL(shift_bits, 1, 1, true, 0, -2) {
    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "shift_bits: actual shift value is missing");

    uint32_t shift = 0;
    if (block.width() > 1) {
        // a second input array takes precedence over the integer argument
        shift = INPUT_VARIABLE(1)->e<uint32_t>(0);
    } else if (block.numI() > 0) {
        shift = INT_ARG(0);
    }

    // message is prefixed with this op's own name so failures are attributed correctly
    REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "shift_bits: can't shift beyond size of data type");

    helpers::shift_bits(block.launchContext(), *input, *output, shift);

    return Status::OK();
}
// Type constraints for shift_bits: allowed input types are limited to
// ALL_INTS and "same mode" is switched on for the op descriptor.
// NOTE(review): same mode presumably makes the output type follow the input type - confirm.
DECLARE_TYPES(shift_bits) {
getOpDescriptor()
->setAllowedInputTypes({ALL_INTS})
->setSameMode(true);
}
}
}
#endif

View File

@ -35,6 +35,29 @@ namespace nd4j {
#if NOT_EXCLUDED(OP_toggle_bits)
DECLARE_OP(toggle_bits, -1, -1, true);
#endif
/**
 * This operation shifts the bits of each element of the input array to the left.
 *
 * The shift amount is taken from the first element of the optional second input,
 * or from the first integer argument when no second input is given. It must be
 * greater than 0 and smaller than the bit width of the element type.
 *
 * PLEASE NOTE: This operation is applicable only to integer data types
 */
#if NOT_EXCLUDED(OP_shift_bits)
DECLARE_CONFIGURABLE_OP(shift_bits, 1, 1, true, 0, -2);
#endif
/**
 * This operation cyclically shifts the bits of each element of the input array
 * to the left: the element is shifted left by the given amount and combined with
 * the same element shifted right by (bit width - amount).
 *
 * The shift amount is taken from the first element of the optional second input,
 * or from the first integer argument when no second input is given. It must be
 * greater than 0 and smaller than the bit width of the element type.
 *
 * PLEASE NOTE: This operation is applicable only to integer data types
 */
#if NOT_EXCLUDED(OP_cyclic_shift_bits)
DECLARE_CONFIGURABLE_OP(cyclic_shift_bits, 1, 1, true, 0, -2);
#endif
}
}

View File

@ -35,8 +35,8 @@ void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& ind
if(outRank == 1) {
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided))
for(Nd4jLong i = 0; i < indLen; ++i) {
Nd4jLong idx = indices.e<Nd4jLong>(i);
@ -54,8 +54,8 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
std::vector<int> dimsToExcludeUpd(sizeOfDims);
std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0);
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // causes known openMP asan bug !
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // causes known openMP asan bug !
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided))
for(Nd4jLong i = 0; i < indLen; ++i) {
NDArray outSubArr = output(indices.e<Nd4jLong>(i), std::vector<int>({0}));
@ -76,8 +76,8 @@ void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& i
if(outRank == 1) {
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided))
for(Nd4jLong i = 0; i < indLen; ++i) {
Nd4jLong idx = indices.e<Nd4jLong>(i);
@ -93,8 +93,8 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0);
std::vector<Nd4jLong> idxRangeOut(2*outRank, 0);
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen/indLastDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided) firstprivate(idxRangeOut))
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided) firstprivate(idxRangeOut))
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen/indLastDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided) firstprivate(idxRangeOut))
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided) firstprivate(idxRangeOut))
for(Nd4jLong i = 0; i < indLen/indLastDim; ++i) {
NDArray indSubArr = indices(i, dimsToExcludeInd);

View File

@ -479,7 +479,7 @@ namespace helpers {
for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) {
auto outputT = listOfOutTensors->at(fi->first);
outputT->assign(listOfTensors->at(fi->second.at(0)));
auto loopSize = fi->second.size();
Nd4jLong loopSize = fi->second.size();
PRAGMA_OMP_PARALLEL_FOR
for (Nd4jLong idx = 1; idx < loopSize; ++idx) {
auto current = listOfTensors->at(fi->second.at(idx));

View File

@ -0,0 +1,54 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/shift.h>
namespace nd4j {
namespace ops {
namespace helpers {
/**
 * Typed implementation of the left bit shift: every element of `input` is
 * shifted left by `shift` positions and the result is written to `output`.
 *
 * @tparam T element type the lambda operates on
 * @param launchContext not referenced by this implementation
 * @param input source array
 * @param output destination array
 * @param shift number of bit positions to shift by
 */
template <typename T>
void shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) {
    auto shiftLeft = LAMBDA_T(value, shift) {
        return value << shift;
    };

    input.applyLambda<T>(shiftLeft, &output);
}
// Public entry point for the left bit shift: forwards all arguments to the
// shift_bits_ implementation chosen from x.dataType() among INTEGER_TYPES.
// NOTE(review): behaviour for a data type outside INTEGER_TYPES depends on BUILD_SINGLE_SELECTOR - confirm.
void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) {
BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES);
}
/**
 * Typed implementation of the cyclic left bit shift (rotate left): every element
 * of `input` is rotated left by `shift` positions and written to `output`.
 *
 * The rotation is performed on the zero-extended bit pattern of the element.
 * Shifting a negative signed value right directly would sign-extend and smear
 * 1-bits over the part that has to come from the left shift.
 *
 * The shift is reduced modulo the bit width of T, so a shift of 0 (or a multiple
 * of the width) leaves elements unchanged instead of producing an invalid shift
 * count.
 *
 * @tparam T integer element type, at most 64 bits wide
 * @param launchContext not referenced by this implementation
 * @param input source array
 * @param output destination array
 * @param shift number of bit positions to rotate by
 */
template <typename T>
void cyclic_shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) {
    const uint32_t width = sizeof(T) * 8;
    const uint32_t left = shift % width;
    // (width - left) % width keeps the right shift count in [0, width - 1]
    const uint32_t right = (width - left) % width;
    // keeps only the low `width` bits after widening to 64 bits
    const unsigned long long mask = ~0ULL >> (64 - width);

    auto lambda = LAMBDA_T(x, left, right, mask) {
        const unsigned long long bits = static_cast<unsigned long long>(x) & mask;
        // bits pushed above `width` by the left shift are dropped by the cast back to T
        return static_cast<T>((bits << left) | (bits >> right));
    };

    input.applyLambda<T>(lambda, &output);
}
// Public entry point for the cyclic bit shift: forwards all arguments to the
// cyclic_shift_bits_ implementation chosen from x.dataType() among INTEGER_TYPES.
// NOTE(review): behaviour for a data type outside INTEGER_TYPES depends on BUILD_SINGLE_SELECTOR - confirm.
void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) {
BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES);
}
}
}
}

View File

@ -562,7 +562,7 @@ static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) {
std::vector<Nd4jLong> coords(maxRank);
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords))
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords))
for (Nd4jLong i = 0; i < zLen; ++i) {
Nd4jLong *zCoordStart, *xCoordStart;

View File

@ -0,0 +1,54 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/shift.h>
namespace nd4j {
namespace ops {
namespace helpers {
/**
 * Typed implementation of the left bit shift: every element of `input` is
 * shifted left by `shift` positions and the result is written to `output`.
 *
 * @tparam T element type the lambda operates on
 * @param launchContext not referenced by this implementation
 * @param input source array
 * @param output destination array
 * @param shift number of bit positions to shift by
 */
template <typename T>
void shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) {
    auto shiftLeft = LAMBDA_T(value, shift) {
        return value << shift;
    };

    input.applyLambda(shiftLeft, &output);
}
// Public entry point for the left bit shift: forwards all arguments to the
// shift_bits_ implementation chosen from x.dataType() among INTEGER_TYPES.
// NOTE(review): behaviour for a data type outside INTEGER_TYPES depends on BUILD_SINGLE_SELECTOR - confirm.
void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) {
BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES);
}
/**
 * Typed implementation of the cyclic left bit shift (rotate left): every element
 * of `input` is rotated left by `shift` positions and written to `output`.
 *
 * The rotation is performed on the zero-extended bit pattern of the element.
 * Shifting a negative signed value right directly would sign-extend and smear
 * 1-bits over the part that has to come from the left shift.
 *
 * The shift is reduced modulo the bit width of T, so a shift of 0 (or a multiple
 * of the width) leaves elements unchanged instead of producing an invalid shift
 * count.
 *
 * @tparam T integer element type, at most 64 bits wide
 * @param launchContext not referenced by this implementation
 * @param input source array
 * @param output destination array
 * @param shift number of bit positions to rotate by
 */
template <typename T>
void cyclic_shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) {
    const uint32_t width = sizeof(T) * 8;
    const uint32_t left = shift % width;
    // (width - left) % width keeps the right shift count in [0, width - 1]
    const uint32_t right = (width - left) % width;
    // keeps only the low `width` bits after widening to 64 bits
    const unsigned long long mask = ~0ULL >> (64 - width);

    auto lambda = LAMBDA_T(x, left, right, mask) {
        const unsigned long long bits = static_cast<unsigned long long>(x) & mask;
        // bits pushed above `width` by the left shift are dropped by the cast back to T
        return static_cast<T>((bits << left) | (bits >> right));
    };

    input.applyLambda(lambda, &output);
}
// Public entry point for the cyclic bit shift: forwards all arguments to the
// cyclic_shift_bits_ implementation chosen from x.dataType() among INTEGER_TYPES.
// NOTE(review): behaviour for a data type outside INTEGER_TYPES depends on BUILD_SINGLE_SELECTOR - confirm.
void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) {
BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES);
}
}
}
}

View File

@ -0,0 +1,38 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef DEV_TESTS_SHIFT_H
#define DEV_TESTS_SHIFT_H
#include <op_boilerplate.h>
#include <types/types.h>
#include <NDArray.h>
namespace nd4j {
namespace ops {
namespace helpers {
// Shifts the bits of every element of x to the left by `shift` positions and writes the result to z.
// The implementation is selected from the data type of x (integer types only).
void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift);
// Cyclically shifts the bits of every element of x to the left by `shift` positions and writes the result to z.
// The implementation is selected from the data type of x (integer types only).
void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift);
}
}
}
#endif //DEV_TESTS_SHIFT_H

View File

@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -33,8 +33,8 @@ class DeclarableOpsTests13 : public testing::Test {
public:
DeclarableOpsTests13() {
printf("\n");
fflush(stdout);
//printf("\n");
//fflush(stdout);
}
};
@ -103,8 +103,9 @@ TEST_F(DeclarableOpsTests13, test_argmax_edge_1) {
nd4j::ops::argmax op;
auto result = op.execute(ctx);
ASSERT_EQ(Status::OK(), result);
nd4j_printf("Done\n","");
//nd4j_printf("Done\n","");
delete ctx;
}
@ -258,7 +259,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_1) {
ASSERT_EQ(result->status(), Status::OK());
result->at(0)->printBuffer("Output");
//result->at(0)->printBuffer("Output");
ASSERT_TRUE(exp1.equalsTo(result->at(0)));
delete result;
}
@ -306,8 +307,8 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_3) {
//nd4j_printf("rows %lld, cols %lld, vals %lld, res full %lld\n", rows.lengthOf(), cols.lengthOf(), vals.lengthOf(), exp1.lengthOf());
ASSERT_EQ(result->status(), Status::OK());
result->at(0)->printBuffer("Output");
exp.printBuffer("Expect");
//result->at(0)->printBuffer("Output");
//exp.printBuffer("Expect");
//result->at(0)->printShapeInfo("Shape output");
ASSERT_TRUE(exp.equalsTo(result->at(0)));
delete result;
@ -327,7 +328,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_1) {
nd4j::ops::barnes_symmetrized op;
auto result = op.execute({&rows, &cols, &vals}, {}, {1});
ASSERT_EQ(result->status(), Status::OK());
result->at(2)->printBuffer("Symmetrized1");
//result->at(2)->printBuffer("Symmetrized1");
ASSERT_TRUE(exp.equalsTo(result->at(2)));
delete result;
@ -346,7 +347,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_2) {
nd4j::ops::barnes_symmetrized op;
auto result = op.execute({&rows, &cols, &vals}, {}, {3});
ASSERT_EQ(result->status(), Status::OK());
result->at(2)->printBuffer("Symmetrized2");
//result->at(2)->printBuffer("Symmetrized2");
// ASSERT_TRUE(exp[i]->equalsTo(result->at(i)));
ASSERT_TRUE(exp.equalsTo(result->at(2)));
delete result;
@ -365,7 +366,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_3) {
nd4j::ops::barnes_symmetrized op;
auto result = op.execute({&rows, &cols, &vals}, {}, {11});
ASSERT_EQ(result->status(), Status::OK());
result->at(2)->printBuffer("Symmetrized3");
//result->at(2)->printBuffer("Symmetrized3");
//exp.printBuffer("EXPect symm3");
// ASSERT_TRUE(exp[i]->equalsTo(result->at(i)));
//ASSERT_TRUE(exp.equalsTo(result->at(0)));
@ -390,10 +391,10 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_4) {
auto result = op.execute({&rows, &cols, &vals}, {}, {11});
ASSERT_EQ(result->status(), Status::OK());
auto res = result->at(2);
res->printBuffer("Symmetrized4");
exp4.printBuffer("Expected sym");
nd4j_printf("Total res is {1, %lld}\n", res->lengthOf());
nd4j_printf("Expected is {1, %lld}\n", exp4.lengthOf());
// res->printBuffer("Symmetrized4");
// exp4.printBuffer("Expected sym");
// nd4j_printf("Total res is {1, %lld}\n", res->lengthOf());
// nd4j_printf("Expected is {1, %lld}\n", exp4.lengthOf());
//exp.printBuffer("EXPect symm3");
// ASSERT_TRUE(exp[i]->equalsTo(result->at(i)));
@ -619,3 +620,38 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_5) {
delete results;
}
// shift_bits with a shift of 4 supplied as an integer argument:
// every element equal to 32 must become 32 << 4 == 512.
TEST_F(DeclarableOpsTests13, shift_bits_1) {
    auto input = NDArrayFactory::create<int>('c', {5});
    auto expected = input.ulike();

    input.assign(32);
    expected.assign(512);

    nd4j::ops::shift_bits shiftOp;
    auto results = shiftOp.execute({&input}, {}, {4});
    ASSERT_EQ(Status::OK(), results->status());

    auto actual = results->at(0);
    ASSERT_EQ(expected, *actual);

    delete results;
}
// cyclic_shift_bits with a shift of 4 supplied as an integer argument.
// For the value 32 no bit leaves the high end of an int, so the expected
// result equals the plain left shift: 512. Wrap-around is not exercised here.
TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) {
    auto input = NDArrayFactory::create<int>('c', {5});
    auto expected = input.ulike();

    input.assign(32);
    expected.assign(512);

    nd4j::ops::cyclic_shift_bits rotateOp;
    auto results = rotateOp.execute({&input}, {}, {4});
    ASSERT_EQ(Status::OK(), results->status());

    auto actual = results->at(0);
    ASSERT_EQ(expected, *actual);

    delete results;
}