diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index 726549415..1404afc96 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -2462,7 +2462,7 @@ double NDArray::getTrace() const { double sum = 0.; -PRAGMA_OMP_PARALLEL_FOR_ARGS(reduction(OMP_SUMT:sum) if(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) +PRAGMA_OMP_PARALLEL_FOR_ARGS(reduction(OMP_SUMT:sum) OMP_IF(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) for(int i = 0; i < minDim; ++i) sum += e(i * offset); diff --git a/libnd4j/blas/cpu/NDArray.cpp b/libnd4j/blas/cpu/NDArray.cpp index a79f81612..2a843f956 100644 --- a/libnd4j/blas/cpu/NDArray.cpp +++ b/libnd4j/blas/cpu/NDArray.cpp @@ -100,7 +100,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, const char std::vector coords(zRank); - PRAGMA_OMP_PARALLEL_FOR_ARGS(if(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords)) + PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords)) for (Nd4jLong i = 0; i < zLen; ++i) { shape::index2coords(zRank, target->shapeOf(), i, zLen, coords.data()); @@ -141,7 +141,7 @@ void NDArray::setIdentity() { minDim = shape[i]; float v = 1.0f; - PRAGMA_OMP_PARALLEL_FOR_ARGS(if(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) + PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) for(int i = 0; i < minDim; ++i) templatedSet(buffer(), i*offset, this->dataType(), &v); } diff --git a/libnd4j/include/helpers/Loops.h b/libnd4j/include/helpers/Loops.h index 3b5627eec..bda04414f 100644 --- a/libnd4j/include/helpers/Loops.h +++ b/libnd4j/include/helpers/Loops.h @@ -922,7 +922,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::EWS1: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint i = 0; i < zLen; ++i) { extraParams[0] = param0; @@ -944,7 +944,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::EWSNONZERO: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint i = 0; i < zLen; ++i) { extraParams[0] = param0; @@ -966,7 +966,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::RANK1: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint i = 0; i < zLen; i++) { extraParams[0] = param0; @@ -990,7 +990,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::RANK2: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint i = 0; i < zLen; i++) { extraParams[0] = param0; @@ -1016,7 +1016,7 @@ void Loops::loopXYZ(const X* x, const 
Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::RANK3: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint i = 0; i < zLen; i++) { extraParams[0] = param0; @@ -1044,7 +1044,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::RANK4: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint i = 0; i < zLen; i++) { extraParams[0] = param0; @@ -1074,7 +1074,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::RANK5: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint i = 0; i < zLen; i++) { extraParams[0] = param0; @@ -1111,7 +1111,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, if(shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint i = 0; i < zLen; ++i) { extraParams[0] = param0; @@ -1135,7 +1135,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, uint castYTadShapeInfo[MAX_RANK]; const bool canCastYTad = nd4j::DataTypeUtils::castShapeInfo(yTadShapeInfo, castYTadShapeInfo); - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint i = 0; i < zLen; ++i) { extraParams[0] = param0; @@ -1199,7 +1199,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::EWS1: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint ix = 0; ix < numXTads; ++ix) { for (uint iy = 0; iy < numYTads; ++iy) { @@ -1224,7 +1224,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::EWSNONZERO: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint ix = 0; ix < numXTads; ++ix) { for (uint iy = 0; iy < numYTads; ++iy) { @@ -1249,7 +1249,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::RANK1: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint ix = 0; ix < numXTads; ++ix) { for (uint iy = 0; iy < numYTads; ++iy) { @@ -1276,7 +1276,7 @@ void Loops::loopXYZ(const X* x, 
const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::RANK2: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint ix = 0; ix < numXTads; ++ix) { for (uint iy = 0; iy < numYTads; ++iy) { @@ -1305,7 +1305,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::RANK3: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint ix = 0; ix < numXTads; ++ix) { for (uint iy = 0; iy < numYTads; ++iy) { @@ -1336,7 +1336,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::RANK4: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint ix = 0; ix < numXTads; ++ix) { for (uint iy = 0; iy < numYTads; ++iy) { @@ -1369,7 +1369,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::RANK5: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint ix = 0; ix < numXTads; ++ix) { for (uint iy = 0; iy < numYTads; ++iy) { @@ -1409,7 +1409,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, if(shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint ix = 0; ix < numXTads; ++ix) { for (uint iy = 0; iy < numYTads; ++iy) { @@ -1435,7 +1435,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, uint castYTadShapeInfo[MAX_RANK]; const bool canCastYTad = nd4j::DataTypeUtils::castShapeInfo(yTadShapeInfo, castYTadShapeInfo); - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint ix = 0; ix < numXTads; ++ix) { for (uint iy = 0; iy < numYTads; ++iy) { diff --git a/libnd4j/include/helpers/cpu/MmulHelper.cpp b/libnd4j/include/helpers/cpu/MmulHelper.cpp index d17d2c021..fbf2fbc20 100644 --- a/libnd4j/include/helpers/cpu/MmulHelper.cpp +++ b/libnd4j/include/helpers/cpu/MmulHelper.cpp @@ -40,7 +40,7 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c const bool flagA = (flagC && transA) || (!flagC && !transA); const bool flagB = (flagC && transB) || (!flagC && !transB); - // PRAGMA_OMP_PARALLEL_FOR_ARGS(if(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) + // PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // for(uint row = 0; row < M; ++row) { // T3* c = flagC ? 
(C + row) : (C + row * ldc); @@ -74,7 +74,7 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c // } // } - PRAGMA_OMP_PARALLEL_FOR_ARGS(if(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided) collapse(2)) + PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided) collapse(2)) for(uint row = 0; row < M; ++row) { for(uint col = 0; col < N; ++col) { @@ -108,7 +108,7 @@ static void usualGemv(const char aOrder, const int M, const int N, const double const bool flagA = aOrder == 'f'; - PRAGMA_OMP_PARALLEL_FOR_ARGS(if(M > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) + PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) for(int row = 0; row < M; ++row) { T3* y = Y + row * incy; @@ -139,7 +139,7 @@ static void usualDot(const Nd4jLong length, const double alpha, const void* vX, T3 alphaZ(alpha), betaZ(beta); T3 sum = 0; - PRAGMA_OMP_PARALLEL_FOR_ARGS(if(length > Environment::getInstance()->elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum)) + PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(length > Environment::getInstance()->elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum)) for(int i = 0; i < length; ++i) sum = sum + X[i * incx] * Y[i * incy]; diff --git a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp new file mode 100644 index 000000000..0bdb9503d --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp @@ -0,0 +1,58 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_cyclic_shift_bits) + +#include +#include +#include + +namespace nd4j { + namespace ops { + CONFIGURABLE_OP_IMPL(cyclic_shift_bits, 1, 1, true, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "cyclic_shift_bits: actual shift value is missing"); + + uint32_t shift = 0; + if (block.width() > 1) { + shift = INPUT_VARIABLE(1)->e<uint32_t>(0); + } else if (block.numI() > 0) { + shift = INT_ARG(0); + } + + REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "cyclic_shift_bits: can't shift beyond size of data type") + + helpers::cyclic_shift_bits(block.launchContext(), *input, *output, shift); + + return Status::OK(); + } + + DECLARE_TYPES(cyclic_shift_bits) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setSameMode(true); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp new file mode 100644 index 000000000..d64e808f4 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp @@ -0,0 +1,58 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License.
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_shift_bits) + +#include +#include +#include + +namespace nd4j { + namespace ops { + CONFIGURABLE_OP_IMPL(shift_bits, 1, 1, true, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "shift_bits: actual shift value is missing"); + + uint32_t shift = 0; + if (block.width() > 1) { + shift = INPUT_VARIABLE(1)->e<uint32_t>(0); + } else if (block.numI() > 0) { + shift = INT_ARG(0); + } + + REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "shift_bits: can't shift beyond size of data type") + + helpers::shift_bits(block.launchContext(), *input, *output, shift); + + return Status::OK(); + } + + DECLARE_TYPES(shift_bits) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setSameMode(true); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/bitwise.h b/libnd4j/include/ops/declarable/headers/bitwise.h index 466431c83..001c836dd 100644 --- a/libnd4j/include/ops/declarable/headers/bitwise.h +++ b/libnd4j/include/ops/declarable/headers/bitwise.h @@ -28,13 +28,36 @@ namespace nd4j { /** * This operation toggles individual bits of each element in array * - * PLEASE NOTE: This operation is possible only on integer datatypes + * PLEASE NOTE: This operation is possible only on integer data types * * @tparam T */ #if NOT_EXCLUDED(OP_toggle_bits) DECLARE_OP(toggle_bits, -1, -1, true); #endif + + + /** + * This operation shifts individual bits of each element in array + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ + #if NOT_EXCLUDED(OP_shift_bits) + DECLARE_CONFIGURABLE_OP(shift_bits, 1, 1, true, 0, -2); + #endif + + /** + * This operation cyclically shifts (rotates) individual bits of each element in array + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ + #if NOT_EXCLUDED(OP_cyclic_shift_bits) + DECLARE_CONFIGURABLE_OP(cyclic_shift_bits, 1, 1, true, 0, -2); + #endif } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp b/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp index 20d9dd05d..c1d01930c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp @@ -35,8 +35,8 @@ void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& ind if(outRank == 1) { -// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) -PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided)) +// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) +PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided)) for(Nd4jLong i = 0; i < indLen; ++i) { Nd4jLong idx = indices.e(i); @@ -54,8 +54,8 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided)) std::vector dimsToExcludeUpd(sizeOfDims); std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0); -// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // causes known openMP asan bug !
-PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided)) +// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // causes known openMP asan bug ! +PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided)) for(Nd4jLong i = 0; i < indLen; ++i) { NDArray outSubArr = output(indices.e(i), std::vector({0})); @@ -76,8 +76,8 @@ void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& i if(outRank == 1) { -// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) -PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided)) +// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) +PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided)) for(Nd4jLong i = 0; i < indLen; ++i) { Nd4jLong idx = indices.e(i); @@ -93,8 +93,8 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided)) std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0); std::vector idxRangeOut(2*outRank, 0); -// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen/indLastDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided) firstprivate(idxRangeOut)) -PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided) firstprivate(idxRangeOut)) +// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen/indLastDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided) firstprivate(idxRangeOut)) +PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided) firstprivate(idxRangeOut)) for(Nd4jLong i = 0; i < indLen/indLastDim; ++i) { NDArray indSubArr = indices(i, dimsToExcludeInd); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp b/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp index 261a742da..6ebfd9b07 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp @@ -479,7 +479,7 @@ namespace helpers { for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { auto outputT = listOfOutTensors->at(fi->first); outputT->assign(listOfTensors->at(fi->second.at(0))); - auto loopSize = fi->second.size(); + Nd4jLong loopSize = fi->second.size(); PRAGMA_OMP_PARALLEL_FOR for (Nd4jLong idx = 1; idx < loopSize; ++idx) { auto current = listOfTensors->at(fi->second.at(idx)); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp b/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp new file mode 100644 index 000000000..d9229faaa --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp @@ -0,0 +1,54 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace nd4j { + namespace ops { + namespace helpers { + template + void shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { + auto lambda = LAMBDA_T(x, shift) { + return x << shift; + }; + + input.applyLambda(lambda, &output); + } + + void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); + } + + template + void cyclic_shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { + auto step = (sizeof(T) * 8) - shift; + auto lambda = LAMBDA_T(x, shift, step) { + return x << shift | x >> step; + }; + + input.applyLambda(lambda, &output); + } + + void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); + } + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp b/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp index 3536f9f62..9b96f34d2 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp @@ -562,7 +562,7 @@ static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) { std::vector coords(maxRank); - PRAGMA_OMP_PARALLEL_FOR_ARGS(if(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords)) + PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords)) for (Nd4jLong i = 0; i < zLen; ++i) { Nd4jLong *zCoordStart, *xCoordStart; diff --git a/libnd4j/include/ops/declarable/helpers/cuda/shift.cu b/libnd4j/include/ops/declarable/helpers/cuda/shift.cu new file mode 100644 index 000000000..bb5902c54 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/shift.cu @@ -0,0 +1,54 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace nd4j { + namespace ops { + namespace helpers { + template + void shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { + auto lambda = LAMBDA_T(x, shift) { + return x << shift; + }; + + input.applyLambda(lambda, &output); + } + + void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); + } + + template + void cyclic_shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { + auto step = (sizeof(T) * 8) - shift; + auto lambda = LAMBDA_T(x, shift, step) { + return x << shift | x >> step; + }; + + input.applyLambda(lambda, &output); + } + + void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); + } + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/shift.h b/libnd4j/include/ops/declarable/helpers/shift.h new file mode 100644 index 000000000..e3d5f40e2 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/shift.h @@ -0,0 +1,38 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef DEV_TESTS_SHIFT_H +#define DEV_TESTS_SHIFT_H + +#include +#include +#include + +namespace nd4j { + namespace ops { + namespace helpers { + void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift); + + void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift); + } + } +} + +#endif //DEV_TESTS_SHIFT_H diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 81e441477..bcbd1de8c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -33,8 +33,8 @@ class DeclarableOpsTests13 : public testing::Test { public: DeclarableOpsTests13() { - printf("\n"); - fflush(stdout); + //printf("\n"); + //fflush(stdout); } }; @@ -103,8 +103,9 @@ TEST_F(DeclarableOpsTests13, test_argmax_edge_1) { nd4j::ops::argmax op; auto result = op.execute(ctx); + ASSERT_EQ(Status::OK(), result); - nd4j_printf("Done\n",""); + //nd4j_printf("Done\n",""); delete ctx; } @@ -258,7 +259,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_1) { ASSERT_EQ(result->status(), Status::OK()); - result->at(0)->printBuffer("Output"); + //result->at(0)->printBuffer("Output"); ASSERT_TRUE(exp1.equalsTo(result->at(0))); delete result; } @@ -306,8 +307,8 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_3) { //nd4j_printf("rows %lld, cols %lld, vals %lld, res full %lld\n", rows.lengthOf(), cols.lengthOf(), vals.lengthOf(), exp1.lengthOf()); ASSERT_EQ(result->status(), Status::OK()); - result->at(0)->printBuffer("Output"); - exp.printBuffer("Expect"); + //result->at(0)->printBuffer("Output"); + //exp.printBuffer("Expect"); //result->at(0)->printShapeInfo("Shape output"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; @@ -327,7 +328,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_1) { nd4j::ops::barnes_symmetrized op; auto result = op.execute({&rows, &cols, &vals}, {}, {1}); ASSERT_EQ(result->status(), Status::OK()); - result->at(2)->printBuffer("Symmetrized1"); + //result->at(2)->printBuffer("Symmetrized1"); ASSERT_TRUE(exp.equalsTo(result->at(2))); delete result; @@ -346,7 +347,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_2) { nd4j::ops::barnes_symmetrized op; auto result = op.execute({&rows, &cols, &vals}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); - result->at(2)->printBuffer("Symmetrized2"); + //result->at(2)->printBuffer("Symmetrized2"); // ASSERT_TRUE(exp[i]->equalsTo(result->at(i))); ASSERT_TRUE(exp.equalsTo(result->at(2))); delete result; @@ -365,7 +366,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_3) { nd4j::ops::barnes_symmetrized op; auto result = op.execute({&rows, &cols, &vals}, {}, {11}); ASSERT_EQ(result->status(), Status::OK()); - result->at(2)->printBuffer("Symmetrized3"); + //result->at(2)->printBuffer("Symmetrized3"); //exp.printBuffer("EXPect symm3"); // ASSERT_TRUE(exp[i]->equalsTo(result->at(i))); //ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -390,10 +391,10 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_4) { auto result = op.execute({&rows, &cols, &vals}, {}, {11}); ASSERT_EQ(result->status(), Status::OK()); auto res = result->at(2); - res->printBuffer("Symmetrized4"); - exp4.printBuffer("Expected sym"); - nd4j_printf("Total res is {1, %lld}\n", res->lengthOf()); - nd4j_printf("Expected is {1, %lld}\n", exp4.lengthOf()); + // res->printBuffer("Symmetrized4"); + // exp4.printBuffer("Expected sym"); + // nd4j_printf("Total res is {1, %lld}\n", res->lengthOf()); + // nd4j_printf("Expected is {1, %lld}\n", exp4.lengthOf()); //exp.printBuffer("EXPect symm3"); // ASSERT_TRUE(exp[i]->equalsTo(result->at(i))); @@ -619,3 +620,38 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_5) { delete results; } + +TEST_F(DeclarableOpsTests13, shift_bits_1) { + auto x = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(32); + e.assign(512); + + nd4j::ops::shift_bits op; + auto result = op.execute({&x}, {}, {4}); + 
ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) { + auto x = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(32); + e.assign(512); + + nd4j::ops::cyclic_shift_bits op; + auto result = op.execute({&x}, {}, {4}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} +
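Note on the recurring OMP_IF change above: this diff replaces the bare OpenMP if(...) clause inside the PRAGMA_OMP_PARALLEL_FOR* helper macros with an OMP_IF(...) wrapper. The wrapper itself is defined elsewhere in libnd4j and is not part of this diff; the sketch below is only an assumption about its general shape, meant to illustrate the idea that the conditional-parallelism clause can be compiled out where the toolchain rejects it, while the enclosing pragma stays valid.

// Hypothetical sketch of an OMP_IF-style wrapper -- NOT the actual libnd4j definition.
// Where the clause is usable, OMP_IF(cond) emits the OpenMP if(cond) clause, so the loop
// is parallelized only when cond holds at runtime; otherwise it expands to nothing and
// the pragma falls back to an unconditionally parallel loop.
#if defined(_OPENMP) && !defined(_MSC_VER)
    #define OMP_IF(condition) if(condition)
#else
    #define OMP_IF(condition)
#endif

With a wrapper of that shape, a call site such as PRAGMA_OMP_PARALLEL_FOR_ARGS(reduction(OMP_SUMT:sum) OMP_IF(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) keeps its "parallelize only above the element-wise threshold" behavior where the clause is honored, and simply runs the loop in parallel where it is not.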