From 987bb80c465891fb2e03981e67682a027d2e4717 Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 15 Aug 2019 20:35:15 +0300 Subject: [PATCH] [WIP] right shift ops (#118) * right shift ops Signed-off-by: raver119 * typo Signed-off-by: raver119 * rotr test Signed-off-by: raver119 --- .../generic/bitwise/cyclic_rshift.cpp | 58 +++++++++++++++++++ .../ops/declarable/generic/bitwise/rshift.cpp | 58 +++++++++++++++++++ .../ops/declarable/generic/bitwise/shift.cpp | 2 +- .../include/ops/declarable/headers/bitwise.h | 26 ++++++++- .../ops/declarable/helpers/cpu/shift.cpp | 27 +++++++++ .../ops/declarable/helpers/cuda/shift.cu | 27 +++++++++ .../include/ops/declarable/helpers/shift.h | 4 ++ .../layers_tests/DeclarableOpsTests13.cpp | 34 +++++++++++ .../java/org/nd4j/nativeblas/Nd4jCpu.java | 53 ++++++++++++++++- 9 files changed, 285 insertions(+), 4 deletions(-) create mode 100644 libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp create mode 100644 libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp diff --git a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp new file mode 100644 index 000000000..2aac5c6f9 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp @@ -0,0 +1,58 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_cyclic_rshift_bits) + +#include +#include +#include + +namespace nd4j { + namespace ops { + CONFIGURABLE_OP_IMPL(cyclic_rshift_bits, 1, 1, true, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "cyclic_rshift_bits: actual shift value is missing"); + + uint32_t shift = 0; + if (block.width() > 1) { + shift = INPUT_VARIABLE(1)->e(0); + } else if (block.numI() > 0) { + shift = INT_ARG(0); + }; + + helpers::cyclic_rshift_bits(block.launchContext(), *input, *output, shift); + + REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "cyclic_rshift_bits: can't shift beyond size of data type") + + return Status::OK(); + } + + DECLARE_TYPES(cyclic_rshift_bits) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setSameMode(true); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp new file mode 100644 index 000000000..4068351a2 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp @@ -0,0 +1,58 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_rshift_bits) + +#include +#include +#include + +namespace nd4j { + namespace ops { + CONFIGURABLE_OP_IMPL(rshift_bits, 1, 1, true, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "rshift_bits: actual shift value is missing"); + + uint32_t shift = 0; + if (block.width() > 1) { + shift = INPUT_VARIABLE(1)->e(0); + } else if (block.numI() > 0) { + shift = INT_ARG(0); + }; + + REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "rshift_bits: can't shift beyond size of data type") + + helpers::rshift_bits(block.launchContext(), *input, *output, shift); + + return Status::OK(); + } + + DECLARE_TYPES(rshift_bits) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setSameMode(true); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp index d64e808f4..f79da1024 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp @@ -40,7 +40,7 @@ namespace nd4j { shift = INT_ARG(0); }; - REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "cyclic_shift_bits: can't shift beyond size of data type") + REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "shift_bits: can't shift beyond size of data type") helpers::shift_bits(block.launchContext(), *input, *output, shift); diff --git a/libnd4j/include/ops/declarable/headers/bitwise.h b/libnd4j/include/ops/declarable/headers/bitwise.h index 001c836dd..900d42816 100644 --- a/libnd4j/include/ops/declarable/headers/bitwise.h +++ b/libnd4j/include/ops/declarable/headers/bitwise.h @@ -38,7 +38,7 @@ namespace nd4j { /** - * This operation shift individual bits of each element in array + * This operation shift individual bits of each element in array to the left: << * * PLEASE NOTE: This operation is applicable only to integer data types * @@ -49,7 +49,18 @@ namespace nd4j { #endif /** - * This operation shift individual bits of each element in array + * This operation shift individual bits of each element in array to the right: >> + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ + #if NOT_EXCLUDED(OP_rshift_bits) + DECLARE_CONFIGURABLE_OP(rshift_bits, 1, 1, true, 0, -2); + #endif + + /** + * This operation shift individual bits of each element in array, shifting to the left * * PLEASE NOTE: This operation is applicable only to integer data types * @@ -58,6 +69,17 @@ namespace nd4j { #if NOT_EXCLUDED(OP_cyclic_shift_bits) DECLARE_CONFIGURABLE_OP(cyclic_shift_bits, 1, 1, true, 0, -2); #endif + + /** + * This operation shift individual bits of each element in array, shifting to the right + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ + #if NOT_EXCLUDED(OP_cyclic_rshift_bits) + DECLARE_CONFIGURABLE_OP(cyclic_rshift_bits, 1, 1, true, 0, -2); + #endif } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp b/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp index d9229faaa..7a9b77b66 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp @@ -23,6 +23,19 @@ namespace nd4j { namespace ops { namespace helpers { + template + void rshift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { + auto lambda = LAMBDA_T(x, shift) { + return x >> shift; + }; + + input.applyLambda(lambda, &output); + } + + void rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), rshift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); + } + template void shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { auto lambda = LAMBDA_T(x, shift) { @@ -36,6 +49,20 @@ namespace nd4j { BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); } + template + void cyclic_rshift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { + auto step = (sizeof(T) * 8) - shift; + auto lambda = LAMBDA_T(x, shift, step) { + return x >> shift | x << step; + }; + + input.applyLambda(lambda, &output); + } + + void cyclic_rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_rshift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); + } + template void cyclic_shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { auto step = (sizeof(T) * 8) - shift; diff --git a/libnd4j/include/ops/declarable/helpers/cuda/shift.cu b/libnd4j/include/ops/declarable/helpers/cuda/shift.cu index bb5902c54..49d388b2a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/shift.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/shift.cu @@ -23,6 +23,19 @@ namespace nd4j { namespace ops { namespace helpers { + template + void rshift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { + auto lambda = LAMBDA_T(x, shift) { + return x >> shift; + }; + + input.applyLambda(lambda, &output); + } + + void rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), rshift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); + } + template void shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { auto lambda = LAMBDA_T(x, shift) { @@ -36,6 +49,20 @@ namespace nd4j { BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); } + template + void cyclic_rshift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { + auto step = (sizeof(T) * 8) - shift; + auto lambda = LAMBDA_T(x, shift, step) { + return x >> shift | x << step; + }; + + input.applyLambda(lambda, &output); + } + + void cyclic_rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_rshift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); + } + template void cyclic_shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { auto step = (sizeof(T) * 8) - shift; diff --git a/libnd4j/include/ops/declarable/helpers/shift.h b/libnd4j/include/ops/declarable/helpers/shift.h index e3d5f40e2..e07a0e992 100644 --- a/libnd4j/include/ops/declarable/helpers/shift.h +++ b/libnd4j/include/ops/declarable/helpers/shift.h @@ -28,8 +28,12 @@ namespace nd4j { namespace ops { namespace helpers { + void rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift); + void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift); + void cyclic_rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift); + void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift); } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index bcbd1de8c..014719270 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -638,6 +638,23 @@ TEST_F(DeclarableOpsTests13, shift_bits_1) { delete result; } +TEST_F(DeclarableOpsTests13, rshift_bits_1) { + auto x = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(512); + e.assign(32); + + nd4j::ops::rshift_bits op; + auto result = op.execute({&x}, {}, {4}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) { auto x = NDArrayFactory::create('c', {5}); auto e = x.ulike(); @@ -655,3 +672,20 @@ TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) { delete result; } +TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_1) { + auto x = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(512); + e.assign(32); + + nd4j::ops::cyclic_rshift_bits op; + auto result = op.execute({&x}, {}, {4}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 52fe5c652..dc0644d47 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -21722,7 +21722,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * This operation toggles individual bits of each element in array * - * PLEASE NOTE: This operation is possible only on integer datatypes + * PLEASE NOTE: This operation is possible only on integer data types * * \tparam T */ @@ -21743,6 +21743,57 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif + + + /** + * This operation shift individual bits of each element in array + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_shift_bits) + @Namespace("nd4j::ops") public static class shift_bits extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public shift_bits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public shift_bits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public shift_bits position(long position) { + return (shift_bits)super.position(position); + } + + public shift_bits() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + + /** + * This operation shift individual bits of each element in array + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_cyclic_shift_bits) + @Namespace("nd4j::ops") public static class cyclic_shift_bits extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cyclic_shift_bits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cyclic_shift_bits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cyclic_shift_bits position(long position) { + return (cyclic_shift_bits)super.position(position); + } + + public cyclic_shift_bits() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif