[WIP] bitwise ops (#115)
* - cyclic_shift_bits + test - shift_bits + test Signed-off-by: raver119 <raver119@gmail.com> * OMP_IF replacement Signed-off-by: raver119 <raver119@gmail.com>master
parent
59cba587f4
commit
6264530dd8
|
@ -2462,7 +2462,7 @@ double NDArray::getTrace() const {
|
|||
|
||||
double sum = 0.;
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(reduction(OMP_SUMT:sum) if(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(reduction(OMP_SUMT:sum) OMP_IF(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||
for(int i = 0; i < minDim; ++i)
|
||||
sum += e<double>(i * offset);
|
||||
|
||||
|
|
|
@ -100,7 +100,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, const char
|
|||
|
||||
std::vector<Nd4jLong> coords(zRank);
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords))
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords))
|
||||
for (Nd4jLong i = 0; i < zLen; ++i) {
|
||||
|
||||
shape::index2coords(zRank, target->shapeOf(), i, zLen, coords.data());
|
||||
|
@ -141,7 +141,7 @@ void NDArray::setIdentity() {
|
|||
minDim = shape[i];
|
||||
|
||||
float v = 1.0f;
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||
for(int i = 0; i < minDim; ++i)
|
||||
templatedSet<float>(buffer(), i*offset, this->dataType(), &v);
|
||||
}
|
||||
|
|
|
@ -922,7 +922,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
|||
//*********************************************//
|
||||
case LoopKind::EWS1: {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||
for (uint i = 0; i < zLen; ++i) {
|
||||
|
||||
extraParams[0] = param0;
|
||||
|
@ -944,7 +944,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
|||
//*********************************************//
|
||||
case LoopKind::EWSNONZERO: {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||
for (uint i = 0; i < zLen; ++i) {
|
||||
|
||||
extraParams[0] = param0;
|
||||
|
@ -966,7 +966,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
|||
//*********************************************//
|
||||
case LoopKind::RANK1: {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||
for (uint i = 0; i < zLen; i++) {
|
||||
|
||||
extraParams[0] = param0;
|
||||
|
@ -990,7 +990,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
|||
//*********************************************//
|
||||
case LoopKind::RANK2: {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||
for (uint i = 0; i < zLen; i++) {
|
||||
|
||||
extraParams[0] = param0;
|
||||
|
@ -1016,7 +1016,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
|||
//*********************************************//
|
||||
case LoopKind::RANK3: {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||
for (uint i = 0; i < zLen; i++) {
|
||||
|
||||
extraParams[0] = param0;
|
||||
|
@ -1044,7 +1044,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
|||
//*********************************************//
|
||||
case LoopKind::RANK4: {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||
for (uint i = 0; i < zLen; i++) {
|
||||
|
||||
extraParams[0] = param0;
|
||||
|
@ -1074,7 +1074,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
|||
//*********************************************//
|
||||
case LoopKind::RANK5: {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||
for (uint i = 0; i < zLen; i++) {
|
||||
|
||||
extraParams[0] = param0;
|
||||
|
@ -1111,7 +1111,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
|||
|
||||
if(shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||
for (uint i = 0; i < zLen; ++i) {
|
||||
|
||||
extraParams[0] = param0;
|
||||
|
@ -1135,7 +1135,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
|||
uint castYTadShapeInfo[MAX_RANK];
|
||||
const bool canCastYTad = nd4j::DataTypeUtils::castShapeInfo<uint>(yTadShapeInfo, castYTadShapeInfo);
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||
for (uint i = 0; i < zLen; ++i) {
|
||||
|
||||
extraParams[0] = param0;
|
||||
|
@ -1199,7 +1199,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
|||
//*********************************************//
|
||||
case LoopKind::EWS1: {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||
for (uint ix = 0; ix < numXTads; ++ix) {
|
||||
for (uint iy = 0; iy < numYTads; ++iy) {
|
||||
|
||||
|
@ -1224,7 +1224,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
|||
//*********************************************//
|
||||
case LoopKind::EWSNONZERO: {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||
for (uint ix = 0; ix < numXTads; ++ix) {
|
||||
for (uint iy = 0; iy < numYTads; ++iy) {
|
||||
|
||||
|
@ -1249,7 +1249,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
|||
//*********************************************//
|
||||
case LoopKind::RANK1: {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||
for (uint ix = 0; ix < numXTads; ++ix) {
|
||||
for (uint iy = 0; iy < numYTads; ++iy) {
|
||||
|
||||
|
@ -1276,7 +1276,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
|||
//*********************************************//
|
||||
case LoopKind::RANK2: {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||
for (uint ix = 0; ix < numXTads; ++ix) {
|
||||
for (uint iy = 0; iy < numYTads; ++iy) {
|
||||
|
||||
|
@ -1305,7 +1305,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
|||
//*********************************************//
|
||||
case LoopKind::RANK3: {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||
for (uint ix = 0; ix < numXTads; ++ix) {
|
||||
for (uint iy = 0; iy < numYTads; ++iy) {
|
||||
|
||||
|
@ -1336,7 +1336,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
|||
//*********************************************//
|
||||
case LoopKind::RANK4: {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||
for (uint ix = 0; ix < numXTads; ++ix) {
|
||||
for (uint iy = 0; iy < numYTads; ++iy) {
|
||||
|
||||
|
@ -1369,7 +1369,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
|||
//*********************************************//
|
||||
case LoopKind::RANK5: {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||
for (uint ix = 0; ix < numXTads; ++ix) {
|
||||
for (uint iy = 0; iy < numYTads; ++iy) {
|
||||
|
||||
|
@ -1409,7 +1409,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
|||
|
||||
if(shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||
for (uint ix = 0; ix < numXTads; ++ix) {
|
||||
for (uint iy = 0; iy < numYTads; ++iy) {
|
||||
|
||||
|
@ -1435,7 +1435,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
|||
uint castYTadShapeInfo[MAX_RANK];
|
||||
const bool canCastYTad = nd4j::DataTypeUtils::castShapeInfo<uint>(yTadShapeInfo, castYTadShapeInfo);
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
|
||||
for (uint ix = 0; ix < numXTads; ++ix) {
|
||||
for (uint iy = 0; iy < numYTads; ++iy) {
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c
|
|||
const bool flagA = (flagC && transA) || (!flagC && !transA);
|
||||
const bool flagB = (flagC && transB) || (!flagC && !transB);
|
||||
|
||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||
// for(uint row = 0; row < M; ++row) {
|
||||
|
||||
// T3* c = flagC ? (C + row) : (C + row * ldc);
|
||||
|
@ -74,7 +74,7 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c
|
|||
// }
|
||||
// }
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided) collapse(2))
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided) collapse(2))
|
||||
for(uint row = 0; row < M; ++row) {
|
||||
for(uint col = 0; col < N; ++col) {
|
||||
|
||||
|
@ -108,7 +108,7 @@ static void usualGemv(const char aOrder, const int M, const int N, const double
|
|||
|
||||
const bool flagA = aOrder == 'f';
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(M > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||
for(int row = 0; row < M; ++row) {
|
||||
|
||||
T3* y = Y + row * incy;
|
||||
|
@ -139,7 +139,7 @@ static void usualDot(const Nd4jLong length, const double alpha, const void* vX,
|
|||
T3 alphaZ(alpha), betaZ(beta);
|
||||
|
||||
T3 sum = 0;
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(length > Environment::getInstance()->elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum))
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(length > Environment::getInstance()->elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum))
|
||||
for(int i = 0; i < length; ++i)
|
||||
sum = sum + X[i * incx] * Y[i * incy];
|
||||
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
#if NOT_EXCLUDED(OP_cyclic_shift_bits)
|
||||
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <ops/declarable/helpers/helpers.h>
|
||||
#include <ops/declarable/helpers/shift.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
CONFIGURABLE_OP_IMPL(cyclic_shift_bits, 1, 1, true, 0, -2) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "cyclic_shift_bits: actual shift value is missing");
|
||||
|
||||
uint32_t shift = 0;
|
||||
if (block.width() > 1) {
|
||||
shift = INPUT_VARIABLE(1)->e<uint32_t>(0);
|
||||
} else if (block.numI() > 0) {
|
||||
shift = INT_ARG(0);
|
||||
};
|
||||
|
||||
helpers::cyclic_shift_bits(block.launchContext(), *input, *output, shift);
|
||||
|
||||
REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "cyclic_shift_bits: can't shift beyond size of data type")
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_TYPES(cyclic_shift_bits) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes({ALL_INTS})
|
||||
->setSameMode(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,58 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
#if NOT_EXCLUDED(OP_shift_bits)
|
||||
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <ops/declarable/helpers/helpers.h>
|
||||
#include <ops/declarable/helpers/shift.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
CONFIGURABLE_OP_IMPL(shift_bits, 1, 1, true, 0, -2) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "shift_bits: actual shift value is missing");
|
||||
|
||||
uint32_t shift = 0;
|
||||
if (block.width() > 1) {
|
||||
shift = INPUT_VARIABLE(1)->e<uint32_t>(0);
|
||||
} else if (block.numI() > 0) {
|
||||
shift = INT_ARG(0);
|
||||
};
|
||||
|
||||
REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "cyclic_shift_bits: can't shift beyond size of data type")
|
||||
|
||||
helpers::shift_bits(block.launchContext(), *input, *output, shift);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_TYPES(shift_bits) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes({ALL_INTS})
|
||||
->setSameMode(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -35,6 +35,29 @@ namespace nd4j {
|
|||
#if NOT_EXCLUDED(OP_toggle_bits)
|
||||
DECLARE_OP(toggle_bits, -1, -1, true);
|
||||
#endif
|
||||
|
||||
|
||||
/**
|
||||
* This operation shift individual bits of each element in array
|
||||
*
|
||||
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||
*
|
||||
* @tparam T
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_shift_bits)
|
||||
DECLARE_CONFIGURABLE_OP(shift_bits, 1, 1, true, 0, -2);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* This operation shift individual bits of each element in array
|
||||
*
|
||||
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||
*
|
||||
* @tparam T
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_cyclic_shift_bits)
|
||||
DECLARE_CONFIGURABLE_OP(cyclic_shift_bits, 1, 1, true, 0, -2);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -35,8 +35,8 @@ void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& ind
|
|||
|
||||
if(outRank == 1) {
|
||||
|
||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
|
||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided))
|
||||
for(Nd4jLong i = 0; i < indLen; ++i) {
|
||||
|
||||
Nd4jLong idx = indices.e<Nd4jLong>(i);
|
||||
|
@ -54,8 +54,8 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
|
|||
std::vector<int> dimsToExcludeUpd(sizeOfDims);
|
||||
std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0);
|
||||
|
||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // causes known openMP asan bug !
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
|
||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // causes known openMP asan bug !
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided))
|
||||
for(Nd4jLong i = 0; i < indLen; ++i) {
|
||||
|
||||
NDArray outSubArr = output(indices.e<Nd4jLong>(i), std::vector<int>({0}));
|
||||
|
@ -76,8 +76,8 @@ void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& i
|
|||
|
||||
if(outRank == 1) {
|
||||
|
||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
|
||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided))
|
||||
for(Nd4jLong i = 0; i < indLen; ++i) {
|
||||
|
||||
Nd4jLong idx = indices.e<Nd4jLong>(i);
|
||||
|
@ -93,8 +93,8 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
|
|||
std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0);
|
||||
std::vector<Nd4jLong> idxRangeOut(2*outRank, 0);
|
||||
|
||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen/indLastDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided) firstprivate(idxRangeOut))
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided) firstprivate(idxRangeOut))
|
||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen/indLastDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided) firstprivate(idxRangeOut))
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided) firstprivate(idxRangeOut))
|
||||
for(Nd4jLong i = 0; i < indLen/indLastDim; ++i) {
|
||||
|
||||
NDArray indSubArr = indices(i, dimsToExcludeInd);
|
||||
|
|
|
@ -479,7 +479,7 @@ namespace helpers {
|
|||
for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) {
|
||||
auto outputT = listOfOutTensors->at(fi->first);
|
||||
outputT->assign(listOfTensors->at(fi->second.at(0)));
|
||||
auto loopSize = fi->second.size();
|
||||
Nd4jLong loopSize = fi->second.size();
|
||||
PRAGMA_OMP_PARALLEL_FOR
|
||||
for (Nd4jLong idx = 1; idx < loopSize; ++idx) {
|
||||
auto current = listOfTensors->at(fi->second.at(idx));
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/shift.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
template <typename T>
|
||||
void shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) {
|
||||
auto lambda = LAMBDA_T(x, shift) {
|
||||
return x << shift;
|
||||
};
|
||||
|
||||
input.applyLambda<T>(lambda, &output);
|
||||
}
|
||||
|
||||
void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) {
|
||||
BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void cyclic_shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) {
|
||||
auto step = (sizeof(T) * 8) - shift;
|
||||
auto lambda = LAMBDA_T(x, shift, step) {
|
||||
return x << shift | x >> step;
|
||||
};
|
||||
|
||||
input.applyLambda<T>(lambda, &output);
|
||||
}
|
||||
|
||||
void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) {
|
||||
BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -562,7 +562,7 @@ static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) {
|
|||
|
||||
std::vector<Nd4jLong> coords(maxRank);
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords))
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords))
|
||||
for (Nd4jLong i = 0; i < zLen; ++i) {
|
||||
|
||||
Nd4jLong *zCoordStart, *xCoordStart;
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/shift.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
template <typename T>
|
||||
void shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) {
|
||||
auto lambda = LAMBDA_T(x, shift) {
|
||||
return x << shift;
|
||||
};
|
||||
|
||||
input.applyLambda(lambda, &output);
|
||||
}
|
||||
|
||||
void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) {
|
||||
BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void cyclic_shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) {
|
||||
auto step = (sizeof(T) * 8) - shift;
|
||||
auto lambda = LAMBDA_T(x, shift, step) {
|
||||
return x << shift | x >> step;
|
||||
};
|
||||
|
||||
input.applyLambda(lambda, &output);
|
||||
}
|
||||
|
||||
void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) {
|
||||
BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#ifndef DEV_TESTS_SHIFT_H
|
||||
#define DEV_TESTS_SHIFT_H
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
#include <types/types.h>
|
||||
#include <NDArray.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift);
|
||||
|
||||
void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif //DEV_TESTS_SHIFT_H
|
|
@ -1,5 +1,5 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -33,8 +33,8 @@ class DeclarableOpsTests13 : public testing::Test {
|
|||
public:
|
||||
|
||||
DeclarableOpsTests13() {
|
||||
printf("\n");
|
||||
fflush(stdout);
|
||||
//printf("\n");
|
||||
//fflush(stdout);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -103,8 +103,9 @@ TEST_F(DeclarableOpsTests13, test_argmax_edge_1) {
|
|||
|
||||
nd4j::ops::argmax op;
|
||||
auto result = op.execute(ctx);
|
||||
ASSERT_EQ(Status::OK(), result);
|
||||
|
||||
nd4j_printf("Done\n","");
|
||||
//nd4j_printf("Done\n","");
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
|
@ -258,7 +259,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_1) {
|
|||
|
||||
|
||||
ASSERT_EQ(result->status(), Status::OK());
|
||||
result->at(0)->printBuffer("Output");
|
||||
//result->at(0)->printBuffer("Output");
|
||||
ASSERT_TRUE(exp1.equalsTo(result->at(0)));
|
||||
delete result;
|
||||
}
|
||||
|
@ -306,8 +307,8 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_3) {
|
|||
|
||||
//nd4j_printf("rows %lld, cols %lld, vals %lld, res full %lld\n", rows.lengthOf(), cols.lengthOf(), vals.lengthOf(), exp1.lengthOf());
|
||||
ASSERT_EQ(result->status(), Status::OK());
|
||||
result->at(0)->printBuffer("Output");
|
||||
exp.printBuffer("Expect");
|
||||
//result->at(0)->printBuffer("Output");
|
||||
//exp.printBuffer("Expect");
|
||||
//result->at(0)->printShapeInfo("Shape output");
|
||||
ASSERT_TRUE(exp.equalsTo(result->at(0)));
|
||||
delete result;
|
||||
|
@ -327,7 +328,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_1) {
|
|||
nd4j::ops::barnes_symmetrized op;
|
||||
auto result = op.execute({&rows, &cols, &vals}, {}, {1});
|
||||
ASSERT_EQ(result->status(), Status::OK());
|
||||
result->at(2)->printBuffer("Symmetrized1");
|
||||
//result->at(2)->printBuffer("Symmetrized1");
|
||||
ASSERT_TRUE(exp.equalsTo(result->at(2)));
|
||||
|
||||
delete result;
|
||||
|
@ -346,7 +347,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_2) {
|
|||
nd4j::ops::barnes_symmetrized op;
|
||||
auto result = op.execute({&rows, &cols, &vals}, {}, {3});
|
||||
ASSERT_EQ(result->status(), Status::OK());
|
||||
result->at(2)->printBuffer("Symmetrized2");
|
||||
//result->at(2)->printBuffer("Symmetrized2");
|
||||
// ASSERT_TRUE(exp[i]->equalsTo(result->at(i)));
|
||||
ASSERT_TRUE(exp.equalsTo(result->at(2)));
|
||||
delete result;
|
||||
|
@ -365,7 +366,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_3) {
|
|||
nd4j::ops::barnes_symmetrized op;
|
||||
auto result = op.execute({&rows, &cols, &vals}, {}, {11});
|
||||
ASSERT_EQ(result->status(), Status::OK());
|
||||
result->at(2)->printBuffer("Symmetrized3");
|
||||
//result->at(2)->printBuffer("Symmetrized3");
|
||||
//exp.printBuffer("EXPect symm3");
|
||||
// ASSERT_TRUE(exp[i]->equalsTo(result->at(i)));
|
||||
//ASSERT_TRUE(exp.equalsTo(result->at(0)));
|
||||
|
@ -390,10 +391,10 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_4) {
|
|||
auto result = op.execute({&rows, &cols, &vals}, {}, {11});
|
||||
ASSERT_EQ(result->status(), Status::OK());
|
||||
auto res = result->at(2);
|
||||
res->printBuffer("Symmetrized4");
|
||||
exp4.printBuffer("Expected sym");
|
||||
nd4j_printf("Total res is {1, %lld}\n", res->lengthOf());
|
||||
nd4j_printf("Expected is {1, %lld}\n", exp4.lengthOf());
|
||||
// res->printBuffer("Symmetrized4");
|
||||
// exp4.printBuffer("Expected sym");
|
||||
// nd4j_printf("Total res is {1, %lld}\n", res->lengthOf());
|
||||
// nd4j_printf("Expected is {1, %lld}\n", exp4.lengthOf());
|
||||
|
||||
//exp.printBuffer("EXPect symm3");
|
||||
// ASSERT_TRUE(exp[i]->equalsTo(result->at(i)));
|
||||
|
@ -619,3 +620,38 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_5) {
|
|||
delete results;
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests13, shift_bits_1) {
|
||||
auto x = NDArrayFactory::create<int>('c', {5});
|
||||
auto e = x.ulike();
|
||||
x.assign(32);
|
||||
e.assign(512);
|
||||
|
||||
nd4j::ops::shift_bits op;
|
||||
auto result = op.execute({&x}, {}, {4});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) {
|
||||
auto x = NDArrayFactory::create<int>('c', {5});
|
||||
auto e = x.ulike();
|
||||
x.assign(32);
|
||||
e.assign(512);
|
||||
|
||||
nd4j::ops::cyclic_shift_bits op;
|
||||
auto result = op.execute({&x}, {}, {4});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue