From fe22bd5726779d22a533c923b9b639443bebd0cc Mon Sep 17 00:00:00 2001 From: AbdelRauf Date: Sun, 28 Feb 2021 19:19:59 +0100 Subject: [PATCH] Compare_and_bitpack: It was reimplemented. now the last dimension should be divisible by 8 Signed-off-by: AbdelRauf --- .../parity_ops/compare_and_bitpack.cpp | 23 +- .../ops/declarable/headers/parity_ops.h | 8 +- .../helpers/cpu/compare_and_bitpack.cpp | 191 +++++++++++++++ .../helpers/cuda/compare_and_bitpack.cu | 191 +++++++++++++++ .../ops/declarable/helpers/transforms.h | 4 +- .../layers_tests/DeclarableOpsTests9.cpp | 223 ++++++++++++++++-- 6 files changed, 603 insertions(+), 37 deletions(-) create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/compare_and_bitpack.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/compare_and_bitpack.cu diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp index f62492a40..a0ccde304 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include namespace sd { @@ -31,17 +32,9 @@ namespace sd { auto x = INPUT_VARIABLE(0); auto y = INPUT_VARIABLE(1); auto z = OUTPUT_VARIABLE(0); - auto z0 = NDArrayFactory::create(x->ordering(), x->getShapeAsVector(), block.launchContext()); - BROADCAST_CHECK_EMPTY(x, y, (&z0)); - auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0); - bitcast res; - auto status = res.execute({ tZ }, { z }, {}, { DataType::UINT8 }, {}, {}, false); - if (tZ != &z0) { - delete tZ; - } - - return status; + sd::ops::helpers::compareAndBitpack(block, *x, *y, *z); + return Status::OK(); } DECLARE_TYPES(compare_and_bitpack) { @@ -53,9 +46,15 @@ namespace sd { DECLARE_SHAPE_FN(compare_and_bitpack) { auto inShape = inputShape->at(0); + auto shapes = 
shape::shapeOf(inShape); + const int rank = shape::rank(inShape); + REQUIRE_TRUE(!shape::isScalar(inShape), 0, "Input should not be a scalar"); + std::vector shapeDims {shapes, shapes + rank}; + REQUIRE_TRUE(shapeDims[rank-1] % 8 ==0 , 0, "Last dimension of the input (which is %i) should be divisible by 8 ", shapeDims[rank-1]); + shapeDims[rank-1] = shapeDims[rank-1] / 8 ; DataType newType = DataType::UINT8; - - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(inShape, newType))); + auto outputShape = ConstantShapeHelper::getInstance().createShapeInfo(newType, shape::order(inShape), shapeDims); + return SHAPELIST(outputShape); } } diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index b3363da9b..e06e1a894 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -1908,15 +1908,15 @@ namespace sd { #endif /** - * compare_and_bitpack - compare with greater and pack result with uint8 + * compare_and_bitpack - Compare values of input to threshold and pack resulting bits into a uint8 * * input params: - * 0 - NDArray (input) - * 1 - 0D Tensor - threshold + * 0 - NDArray (input). Note: last dimension should be divisible by 8 + * 1 - 0D Tensor - threshold to compare against. 
Note: when input and threshold are of bool type, the threshold is ignored * * * output: - * 0 - NDArray with the same shape as input and type uint8 + * 0 - NDArray with shape {input.dim0, ..., input.dimLast/8} and type uint8; each group of 8 consecutive last-dimension values is packed MSB-first (first value -> bit 7) */ #if NOT_EXCLUDED(OP_compare_and_bitpack) DECLARE_CUSTOM_OP(compare_and_bitpack, 2, 1, false, 0, 0); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compare_and_bitpack.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compare_and_bitpack.cpp new file mode 100644 index 000000000..3c3a9c509 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/compare_and_bitpack.cpp @@ -0,0 +1,191 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + + // + // @author AbdelRauf + // + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + namespace helpers { + + + template + uint8_t pack(const X* buff, const X& threshold){ + uint8_t res; + res = (buff[0] > threshold) << 7; + res = res | ((buff[1] > threshold) << 6); + res = res | ((buff[2] > threshold) << 5); + res = res | ((buff[3] > threshold) << 4); + res = res | ((buff[4] > threshold) << 3); + res = res | ((buff[5] > threshold) << 2); + res = res | ((buff[6] > threshold) << 1); + res = res | (buff[7] > threshold); + return res; + } + + template<> + uint8_t pack(const bool* buff, const bool &threshold){ + //ignore threshold + uint8_t res; + res = buff[0] << 7; + res = res | (buff[1] << 6); + res = res | (buff[2] << 5); + res = res | (buff[3] << 4); + res = res | (buff[4] << 3); + res = res | (buff[5] << 2); + res = res | (buff[6] << 1); + res = res | buff[7] ; + return res; + } + + template + uint8_t pack(const X* buff, int stride, const X& threshold){ + uint8_t res; + res = (buff[0] > threshold) << 7; + res = res | ((buff[1*stride] > threshold) << 6); + res = res | ((buff[2*stride] > threshold) << 5); + res = res | ((buff[3*stride] > threshold) << 4); + res = res | ((buff[4*stride] > threshold) << 3); + res = res | ((buff[5*stride] > threshold) << 2); + res = res | ((buff[6*stride] > threshold) << 1); + res = res | (buff[7*stride] > threshold); + return res; + } + + template<> + uint8_t pack(const bool* buff, int stride, const bool &threshold){ + //ignore threshold + uint8_t res; + res = buff[0] << 7; + res = res | (buff[1*stride] << 6); + res = res | (buff[2*stride] << 5); + res = res | (buff[3*stride] << 4); + res = res | (buff[4*stride] << 3); + res = res | (buff[5*stride] << 2); + res = res | (buff[6*stride] << 1); + res = res | buff[7*stride] ; + 
return res; + } + + + template + void compareAndBitpack_(const NDArray& input, const NDArray& thresholdScalar, NDArray& output) { + + auto rank =input.rankOf(); + X threshold = thresholdScalar.e(0); + auto buff = input.bufferAsT(); + uint8_t *outBuff = output.bufferAsT(); + if(input.ordering()=='c' && output.ordering()=='c' && input.ews()==1 && output.ews()==1){ + FUNC_1D func = [buff, outBuff, threshold](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void { + //nd4j_printf("s: %i e: %i \n", (int)start,(int)stop); + auto outBuffPart = outBuff + start; + auto buffPart = buff + start*8; + auto len = stop-start; + //run + for(auto i=0;i < len; i++){ + outBuffPart[i] = pack(&(buffPart[8*i]), threshold); + } + }; + samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1); + + } + else{ + + auto inShapes = input.shapeOf(); + auto outShapes = output.shapeOf(); + auto inStrides = input.stridesOf(); + auto outStrides = output.stridesOf(); + + if(rank == 1){ + auto inLastStride = inStrides[rank-1]; + auto outLastStride = outStrides[rank-1]; + FUNC_1D func = [buff, outBuff, inLastStride, outLastStride, threshold](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void { + //nd4j_printf("rankkk s: %i e: %i \n", (int)start,(int)stop); + auto buffPart = buff + start*8*inLastStride; + auto outBuffPart = outBuff + start* outLastStride; + auto len = stop-start; + //run + for(auto i=0;i < len; i++){ + *outBuffPart = pack(buffPart, inLastStride, threshold); + buffPart += 8*inLastStride; + outBuffPart += outLastStride; + + } + }; + samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1); + }else{ + //if output shape is {n1, n2, n3} then input shape is { n1. n2, n3 * 8} + //therefore we can split input shape {n1, n2, n3 , 8} and correct its stride + //as we do not need last shape info. 
lets just extend and correct its stride + Nd4jLong extendedStrides[MAX_RANK]; + for(int i=0;i void { + Nd4jLong coords[MAX_RANK] = {}; + Nd4jLong* ptr_coords = (Nd4jLong*)&coords; + //nd4j_printf("generic s: %i e: %i \n", (int)start,(int)stop); + auto len = (stop-start); + // its extended as {rank+1} so extendedStrides[rank] is valid + auto innermostStride = extendedStrides[rank]; + sd::index2coords_C(start, rank, outShapes, ptr_coords); + //here last dimension will not be in coords. this way output shape and input shapes are equal + auto offset = sd::offset_from_coords(extendedStrides, outStrides, ptr_coords, rank); + for(auto k=0; k < len; k++){ + auto buffPart = &(buff[offset.first]); + auto outBuffPart = &(outBuff[offset.second]); + *outBuffPart = pack(buffPart, innermostStride, threshold); + offset = inc_coords(outShapes, extendedStrides, outStrides, ptr_coords, offset, rank); + } + }; + samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1); + } + + } + } + + ///////////////////////////////////////////////////////////// + void compareAndBitpack(sd::graph::Context& block, const NDArray& input, const NDArray& threshold, NDArray& output) { + + BUILD_SINGLE_SELECTOR(input.dataType(), compareAndBitpack_, (input, threshold, output), LIBND4J_TYPES); + } + + BUILD_SINGLE_TEMPLATE(template void compareAndBitpack_, (const NDArray& input, const NDArray& threshold, NDArray& output), LIBND4J_TYPES); + + } + } +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/compare_and_bitpack.cu b/libnd4j/include/ops/declarable/helpers/cuda/compare_and_bitpack.cu new file mode 100644 index 000000000..d24233b3d --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/compare_and_bitpack.cu @@ -0,0 +1,191 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * 
https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + + // + // @author AbdelRauf + // + +#include +#include +#include +#include +#include +#include +#include +namespace sd { +namespace ops { +namespace helpers { + + template + _CUDA_HD uint8_t pack(const X* buff, const X& threshold){ + uint8_t res; + res = (buff[0] > threshold) << 7; + res = res | ((buff[1] > threshold) << 6); + res = res | ((buff[2] > threshold) << 5); + res = res | ((buff[3] > threshold) << 4); + res = res | ((buff[4] > threshold) << 3); + res = res | ((buff[5] > threshold) << 2); + res = res | ((buff[6] > threshold) << 1); + res = res | (buff[7] > threshold); + return res; + } + + template<> + _CUDA_HD uint8_t pack(const bool* buff, const bool &threshold){ + //ignore threshold + uint8_t res; + res = buff[0] << 7; + res = res | (buff[1] << 6); + res = res | (buff[2] << 5); + res = res | (buff[3] << 4); + res = res | (buff[4] << 3); + res = res | (buff[5] << 2); + res = res | (buff[6] << 1); + res = res | buff[7] ; + return res; + } + + template + _CUDA_HD uint8_t pack(const X* buff, int stride, const X& threshold){ + uint8_t res; + res = (buff[0] > threshold) << 7; + res = res | ((buff[1*stride] > threshold) << 6); + res = res | ((buff[2*stride] > threshold) << 5); + res = res | ((buff[3*stride] > threshold) << 4); + res = res | ((buff[4*stride] > threshold) << 3); + res = res | ((buff[5*stride] > threshold) << 2); + 
res = res | ((buff[6*stride] > threshold) << 1); + res = res | (buff[7*stride] > threshold); + return res; + } + + template<> + _CUDA_HD uint8_t pack(const bool* buff, int stride, const bool &threshold){ + //ignore threshold + uint8_t res; + res = buff[0] << 7; + res = res | (buff[1*stride] << 6); + res = res | (buff[2*stride] << 5); + res = res | (buff[3*stride] << 4); + res = res | (buff[4*stride] << 3); + res = res | (buff[5*stride] << 2); + res = res | (buff[6*stride] << 1); + res = res | buff[7*stride] ; + return res; + } +/////////////////////////////////////////////////////////////////// +template +static void _CUDA_G cmpBitpack(const void* vx, void* vz, int rank, int len, const Nd4jLong *xStridesExtended, const Nd4jLong *outPutShapeInfo, T threshold) { + + const T* x = reinterpret_cast(vx); + uint8_t* z = reinterpret_cast(vz); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto shapes = shape::shapeOf(outPutShapeInfo); + auto zStrides = shape::stride(outPutShapeInfo); + Nd4jLong coords[MAX_RANK] = {}; + Nd4jLong* ptr_coords = (Nd4jLong*)&coords; + // its extended as {rank+1} so xStridesExtended[rank] is valid + auto inLastStride = xStridesExtended[rank]; + + for(auto k=tid; k < len; k+=gridDim.x * blockDim.x){ + sd::index2coords_C(k, rank, shapes, ptr_coords); + auto offset = sd::offset_from_coords(xStridesExtended, zStrides, ptr_coords, rank); + auto buffPart = &(x[offset.first]); + auto outBuffPart = &(z[offset.second]); + *outBuffPart = pack(buffPart, inLastStride, threshold); + } +} + +template +static void _CUDA_G cmpBitpackEws(const void* vx, void* vz, int len, const Nd4jLong xStride, const Nd4jLong yStride, T threshold) { + + const T* x = reinterpret_cast(vx); + uint8_t* z = reinterpret_cast(vz); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + if(xStride==1){ + for(auto k=tid; k < len; k+=gridDim.x * blockDim.x){ + auto buffPart = &(x[k*8]); + auto outBuffPart = &(z[k*yStride]); + *outBuffPart = pack(buffPart, threshold); 
+ } + }else{ + for(auto k=tid; k < len; k+=gridDim.x * blockDim.x){ + auto buffPart = &(x[k*8*xStride]); + auto outBuffPart = &(z[k*yStride]); + *outBuffPart = pack(buffPart, xStride, threshold); + } + } +} + +/////////////////////////////////////////////////////////////////// +template +static _CUDA_H void cmpBitpackCudaLauncher(sd::graph::Context& block, const NDArray& input, const NDArray& thresholdScalar, NDArray& output) { + T threshold = thresholdScalar.e(0); + + + auto inStrides = input.stridesOf(); + auto rank = output.rankOf(); + + //threadblock size + const int threadsPerBlock = MAX_NUM_THREADS / 2; + //grid size + const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + auto stream = block.launchContext()->getCudaStream(); + //nd4j_printf("n %i g %i th %i \n", output.lengthOf(), blocksPerGrid, threadsPerBlock); + PointersManager manager(block.launchContext(), "compare_and_bitpack"); + NDArray::prepareSpecialUse({&output}, {&input}); + if(input.ews()>0 && output.ews()>0 && input.ordering()=='c' && output.ordering()=='c'){ + cmpBitpackEws<<>>(input.specialBuffer(), output.specialBuffer(), output.lengthOf(), inStrides[rank-1], output.stridesOf()[rank-1] , threshold); + }else{ + //if output shape is {n1, n2, n3} then input shape is { n1. n2, n3 * 8} + //therefore we can split input shape {n1, n2, n3 , 8} and correct its stride + //as we do not need last shape info. 
lets just extend and correct its stride + Nd4jLong extendedStrides[MAX_RANK]; + for(int i=0;i(manager.replicatePointer(extendedStrides, strideSize)); + cmpBitpack<<>>(input.specialBuffer(), output.specialBuffer(), rank, output.lengthOf(), extendedStridesDevPtr, output.specialShapeInfo(), threshold); + } + + NDArray::registerSpecialUse({&output}, {&input}); + manager.synchronize(); + +} + + +void compareAndBitpack(sd::graph::Context& block, const NDArray& input, const NDArray& threshold, NDArray& output) { + + BUILD_SINGLE_SELECTOR(input.dataType(), cmpBitpackCudaLauncher, (block, input, threshold, output), LIBND4J_TYPES); +} + + + +} +} +} + diff --git a/libnd4j/include/ops/declarable/helpers/transforms.h b/libnd4j/include/ops/declarable/helpers/transforms.h index bcb0f8ee5..fdebbf253 100644 --- a/libnd4j/include/ops/declarable/helpers/transforms.h +++ b/libnd4j/include/ops/declarable/helpers/transforms.h @@ -26,7 +26,7 @@ #include #include #include - +#include namespace sd { namespace ops { namespace helpers { @@ -84,6 +84,8 @@ namespace helpers { void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector reps); void split(sd::LaunchContext* context, const NDArray& input, std::vector& outArrs, const int axis); + + void compareAndBitpack(graph::Context& block, const NDArray& input, const NDArray& threshold, NDArray& output); } } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index 9a59a0bfe..91ebb5ba6 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -27,7 +27,7 @@ #include #include #include - +#include using namespace sd; @@ -1757,6 +1757,208 @@ TEST_F(DeclarableOpsTests9, prelu_test14) { } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) { + + auto x = 
NDArrayFactory::create('c', {2, 3, 16}, { + 0.865595f, 0.381197f, 0.911656f, 0.256752f, 0.084921f, 0.070434f, 0.469923f, 0.269935f, 0.510656f, 0.949777f, 0.926772f, 0.622540f, 0.688253f, 0.164974f, + 0.068558f, 0.031173f, 0.910035f, 0.219362f, 0.731336f, 0.135392f, 0.449875f, 0.020135f, 0.891820f, 0.907567f, 0.114376f, 0.652253f, 0.892939f, 0.698095f, + 0.423831f, 0.971155f, 0.968733f, 0.194465f, 0.852475f, 0.642962f, 0.417665f, 0.768379f, 0.753035f, 0.738440f, 0.046251f, 0.659487f, 0.486230f, 0.246724f, + 0.276700f, 0.103631f, 0.843105f, 0.562587f, 0.784459f, 0.109871f, 0.455828f, 0.129641f, 0.002471f, 0.148281f, 0.976162f, 0.603573f, 0.752530f, 0.249840f, + 0.723716f, 0.658430f, 0.661057f, 0.328042f, 0.338351f, 0.903157f, 0.485580f, 0.405103f, 0.335052f, 0.509858f, 0.764852f, 0.764527f, 0.382572f, 0.962121f, + 0.296145f, 0.602766f, 0.169683f, 0.750371f, 0.993936f, 0.914704f, 0.199342f, 0.858098f, 0.617198f, 0.219334f, 0.167574f, 0.305204f, 0.960773f, 0.537944f, + 0.245441f, 0.787276f, 0.968920f, 0.980918f, 0.615237f, 0.355165f, 0.480441f, 0.304282f, 0.961229f, 0.639195f, 0.017776f, 0.836153f + }); + auto threshold = NDArrayFactory::create(0.5f); + auto exp = NDArrayFactory::create('c', {2, 3, 2}, {160, 248, 163, 118, 221, 14, 14, 228, 117, 118, 55, 141}); + + sd::ops::compare_and_bitpack op; + auto result = op.evaluate({&x, &threshold}, {}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + +TEST_F(DeclarableOpsTests9, compare_and_bitpack_test2) { + + auto x = NDArrayFactory::create('c', {2, 3, 16}, { + true, false, true, false, false, false, false, false, true, + true, true, true, true, false, false, false, true, false, + true, false, false, false, true, true, false, true, true, + true, false, true, true, false, true, true, false, true, + true, true, false, true, false, false, false, false, true, + true, true, false, false, false, false, false, 
true, true, + true, false, true, true, true, false, false, true, false, + false, false, true, true, true, false, true, false, true, + false, true, true, true, false, true, true, false, false, + false, true, true, false, true, true, true, true, false, + false, false, true, true, false, true + }); + //threshold is ignored here ,actually + auto threshold = NDArrayFactory::create(true); + auto exp = NDArrayFactory::create('c', {2, 3, 2}, {160, 248, 163, 118, 221, 14, 14, 228, 117, 118, 55, 141}); + + sd::ops::compare_and_bitpack op; + auto result = op.evaluate({&x, &threshold}, {}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, compare_and_bitpack_test3) { + + auto x = NDArrayFactory::create('c', {2, 0, 3, 16}); + auto threshold = NDArrayFactory::create(0.5f); + auto exp = NDArrayFactory::create('c', {2, 0, 3, 2}); + + sd::ops::compare_and_bitpack op; + auto result = op.evaluate({&x, &threshold}, {}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + output->printShapeInfo("output"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, compare_and_bitpack_test4) { + + auto x = NDArrayFactory::create('c', {2, 0, 3, 13}); + auto threshold = NDArrayFactory::create(0.5f); + sd::ops::compare_and_bitpack op; + + ASSERT_THROW(op.evaluate({&x, &threshold}, {}, {}, {}), std::invalid_argument); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, compare_and_bitpack_test5) { + + auto x = NDArrayFactory::create('c', {2, 0, 3, 13}); + auto threshold = NDArrayFactory::create(0.5f); + auto out = 
NDArrayFactory::create('c', {2, 0, 3, 1}); + sd::ops::compare_and_bitpack op; + + ASSERT_THROW(op.execute({&x, &threshold}, {&out}, {}, {}), std::invalid_argument); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, compare_and_bitpack_test6) { + + auto x = NDArrayFactory::create('c', {2, 0, 3, 8}); + auto threshold = NDArrayFactory::create(0.5f); + auto out = NDArrayFactory::create('c', {2, 0, 3, 2}); + sd::ops::compare_and_bitpack op; + //shape mismatch throws runtime error + ASSERT_THROW(op.execute({&x, &threshold}, {&out}, {}, {}), std::runtime_error); + +} + +TEST_F(DeclarableOpsTests9, compare_and_bitpack_test7) { + constexpr int pp = 32*32*16; + constexpr int s1 = 3; + constexpr int t1 = 8; + std::vector shape1 = {pp}; + std::vector strides1 = {s1}; + std::vector shape2 = {pp/8}; + std::vector strides2 = {t1}; + ShapeDescriptor desc1 (DataType::BOOL, 'c', shape1, strides1, s1); + ShapeDescriptor desc2 (DataType::UINT8, 'c', shape2, strides2, t1); + auto x = NDArrayFactory::create(desc1); + auto output = NDArrayFactory::create(desc2); + auto exp = NDArrayFactory::create(desc2); + auto threshold = NDArrayFactory::create(true); + auto buff = x.bufferAsT(); + uint8_t *expBuff = exp.bufferAsT(); + //generate test + for(int l=0;l shape1 = {pp,pp,pp}; + std::vector strides1 = {s3 , s2 , s1}; + std::vector shape2 = {pp,pp,pp/8}; + std::vector strides2 = {t3 , t2 , t1}; + ShapeDescriptor desc1 (DataType::BOOL, 'c', shape1, strides1, 0); + ShapeDescriptor desc2 (DataType::UINT8, 'c', shape2, strides2, 0); + auto x = NDArrayFactory::create(desc1); + auto output = NDArrayFactory::create(desc2); + auto exp = NDArrayFactory::create(desc2); + auto threshold = NDArrayFactory::create(true); + auto buff = x.bufferAsT(); + uint8_t *expBuff = exp.bufferAsT(); + //generate test + for(int i=0;i('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 
6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto threshold = NDArrayFactory::create(2.0); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - - sd::ops::compare_and_bitpack op; - - auto result = op.evaluate({&x, &threshold}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); -// output->printIndexedBuffer("Packed to uint8"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - -} - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {