Compare_and_bitpack: reimplemented; the last dimension must now be divisible by 8

Signed-off-by: AbdelRauf <rauf@konduit.ai>
master
AbdelRauf 2021-02-28 19:19:59 +01:00
parent b66454d593
commit fe22bd5726
6 changed files with 603 additions and 37 deletions
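For reference, a minimal standalone sketch of the op's new contract (plain C++, independent of the libnd4j types in the diff below; the function name is illustrative): each group of 8 values along the last axis is compared against the threshold and packed into one byte, most significant bit first.

#include <cstdint>
#include <vector>

// out[i] has bit (7 - j) set iff in[8*i + j] > threshold (MSB-first packing)
std::vector<uint8_t> compareAndBitpackRef(const std::vector<float>& in, float threshold) {
    std::vector<uint8_t> out(in.size() / 8);   // in.size() must be divisible by 8
    for (size_t i = 0; i < out.size(); i++) {
        uint8_t byte = 0;
        for (int j = 0; j < 8; j++)
            byte |= static_cast<uint8_t>(in[8 * i + j] > threshold) << (7 - j);
        out[i] = byte;
    }
    return out;
}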

View File

@@ -23,6 +23,7 @@
#include <ops/declarable/generic/helpers/BroadcastHelper.h>
#include <ops/declarable/headers/parity_ops.h>
#include <ops/declarable/headers/datatypes.h>
+#include <ops/declarable/helpers/transforms.h>
#include <array/NDArrayFactory.h>
namespace sd {
@@ -31,17 +32,9 @@ namespace sd {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
-auto z0 = NDArrayFactory::create<bool>(x->ordering(), x->getShapeAsVector(), block.launchContext());
-BROADCAST_CHECK_EMPTY(x, y, (&z0));
-auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0);
-bitcast res;
-auto status = res.execute({ tZ }, { z }, {}, { DataType::UINT8 }, {}, {}, false);
-if (tZ != &z0) {
-    delete tZ;
-}
-return status;
+sd::ops::helpers::compareAndBitpack(block, *x, *y, *z);
+return Status::OK();
}
DECLARE_TYPES(compare_and_bitpack) {
@@ -53,9 +46,15 @@ namespace sd {
DECLARE_SHAPE_FN(compare_and_bitpack) {
auto inShape = inputShape->at(0);
+auto shapes = shape::shapeOf(inShape);
+const int rank = shape::rank(inShape);
+REQUIRE_TRUE(!shape::isScalar(inShape), 0, "Input should not be a scalar");
+std::vector<Nd4jLong> shapeDims {shapes, shapes + rank};
+REQUIRE_TRUE(shapeDims[rank-1] % 8 == 0, 0, "Last dimension of the input (which is %i) should be divisible by 8", shapeDims[rank-1]);
+shapeDims[rank-1] = shapeDims[rank-1] / 8;
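+// e.g. a float input of shape {2, 3, 16} yields a uint8 output of shape {2, 3, 2}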
DataType newType = DataType::UINT8;
-return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(inShape, newType)));
+auto outputShape = ConstantShapeHelper::getInstance().createShapeInfo(newType, shape::order(inShape), shapeDims);
+return SHAPELIST(outputShape);
}
}

View File

@@ -1908,15 +1908,15 @@ namespace sd {
#endif
/**
- * compare_and_bitpack - compare with greater and pack result with uint8
+ * compare_and_bitpack - compare values of input against a threshold and pack the resulting bits into uint8
  *
  * input params:
- *    0 - NDArray (input)
- *    1 - 0D Tensor - threshold
+ *    0 - NDArray (input). Note: the last dimension should be divisible by 8
+ *    1 - 0D Tensor - threshold to compare against. Note: when input and threshold are of bool type, the threshold is ignored
  *
  *
  * output:
- *    0 - NDArray with the same shape as input and type uint8
+ *    0 - NDArray with shape {input.dim0, ..., input.dimLast/8} and type uint8
  */
#if NOT_EXCLUDED(OP_compare_and_bitpack)
DECLARE_CUSTOM_OP(compare_and_bitpack, 2, 1, false, 0, 0);
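A minimal usage sketch of the reworked op (shapes and values here are illustrative; the call pattern mirrors the tests added further down):

// every value exceeds the 0.5 threshold, so each packed byte is 0xFF
auto x = NDArrayFactory::create<float>('c', {2, 8});
x.assign(0.7f);
auto threshold = NDArrayFactory::create<float>(0.5f);
sd::ops::compare_and_bitpack op;
auto result = op.evaluate({&x, &threshold}, {}, {}, {});
auto z = result.at(0);   // uint8 NDArray of shape {2, 1}, both bytes 255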

View File

@@ -0,0 +1,191 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
//
// @author AbdelRauf
//
#include <type_traits>
#include <cmath>
#include <stdexcept>
#include <memory>
#include <execution/Threads.h>
#include <execution/ThreadPool.h>
#include <helpers/LoopsCoordsHelper.h>
#include <ops/declarable/helpers/transforms.h>
namespace sd {
namespace ops {
namespace helpers {
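//pack: folds 8 consecutive comparison results (buff[i] > threshold) into one byte, most significant bit first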
template<typename X>
uint8_t pack(const X* buff, const X& threshold){
uint8_t res;
res = (buff[0] > threshold) << 7;
res = res | ((buff[1] > threshold) << 6);
res = res | ((buff[2] > threshold) << 5);
res = res | ((buff[3] > threshold) << 4);
res = res | ((buff[4] > threshold) << 3);
res = res | ((buff[5] > threshold) << 2);
res = res | ((buff[6] > threshold) << 1);
res = res | (buff[7] > threshold);
return res;
}
template<>
uint8_t pack<bool>(const bool* buff, const bool &threshold){
//ignore threshold
uint8_t res;
res = buff[0] << 7;
res = res | (buff[1] << 6);
res = res | (buff[2] << 5);
res = res | (buff[3] << 4);
res = res | (buff[4] << 3);
res = res | (buff[5] << 2);
res = res | (buff[6] << 1);
res = res | buff[7] ;
return res;
}
template<typename X>
uint8_t pack(const X* buff, int stride, const X& threshold){
uint8_t res;
res = (buff[0] > threshold) << 7;
res = res | ((buff[1*stride] > threshold) << 6);
res = res | ((buff[2*stride] > threshold) << 5);
res = res | ((buff[3*stride] > threshold) << 4);
res = res | ((buff[4*stride] > threshold) << 3);
res = res | ((buff[5*stride] > threshold) << 2);
res = res | ((buff[6*stride] > threshold) << 1);
res = res | (buff[7*stride] > threshold);
return res;
}
template<>
uint8_t pack<bool>(const bool* buff, int stride, const bool &threshold){
//ignore threshold
uint8_t res;
res = buff[0] << 7;
res = res | (buff[1*stride] << 6);
res = res | (buff[2*stride] << 5);
res = res | (buff[3*stride] << 4);
res = res | (buff[4*stride] << 3);
res = res | (buff[5*stride] << 2);
res = res | (buff[6*stride] << 1);
res = res | buff[7*stride] ;
return res;
}
template <typename X>
void compareAndBitpack_(const NDArray& input, const NDArray& thresholdScalar, NDArray& output) {
auto rank =input.rankOf();
X threshold = thresholdScalar.e<X>(0);
auto buff = input.bufferAsT<X>();
uint8_t *outBuff = output.bufferAsT<uint8_t>();
if(input.ordering()=='c' && output.ordering()=='c' && input.ews()==1 && output.ews()==1){
FUNC_1D func = [buff, outBuff, threshold](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
//nd4j_printf("s: %i e: %i \n", (int)start,(int)stop);
auto outBuffPart = outBuff + start;
auto buffPart = buff + start*8;
auto len = stop-start;
//run
for(auto i=0;i < len; i++){
outBuffPart[i] = pack<X>(&(buffPart[8*i]), threshold);
}
};
samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
}
else{
auto inShapes = input.shapeOf();
auto outShapes = output.shapeOf();
auto inStrides = input.stridesOf();
auto outStrides = output.stridesOf();
if(rank == 1){
auto inLastStride = inStrides[rank-1];
auto outLastStride = outStrides[rank-1];
FUNC_1D func = [buff, outBuff, inLastStride, outLastStride, threshold](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
//nd4j_printf("rankkk s: %i e: %i \n", (int)start,(int)stop);
auto buffPart = buff + start*8*inLastStride;
auto outBuffPart = outBuff + start* outLastStride;
auto len = stop-start;
//run
for(auto i=0;i < len; i++){
*outBuffPart = pack<X>(buffPart, inLastStride, threshold);
buffPart += 8*inLastStride;
outBuffPart += outLastStride;
}
};
samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
}else{
//if the output shape is {n1, n2, n3} then the input shape is {n1, n2, n3 * 8}
//therefore we can treat the input as shape {n1, n2, n3, 8} and correct its strides;
//as we do not need the last dimension's shape info, we just extend the strides array and correct its entries
Nd4jLong extendedStrides[MAX_RANK];
for(int i=0;i<rank; i++){
extendedStrides[i] = inStrides[i];
}
//lets correct new stride
extendedStrides[rank-1] = 8*inStrides[rank-1];
extendedStrides[rank] = inStrides[rank-1];
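//e.g. with rank 3 and input strides {s1, s2, s3}, extendedStrides becomes {s1, s2, 8*s3, s3}:
//the first 3 entries walk the output coordinates, the extra entry strides over the 8 values packed into one byte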
//generic case: correct but slow; it could be specialized further later
FUNC_1D func = [rank, buff, outBuff, outShapes, extendedStrides, outStrides, threshold]
(uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
Nd4jLong coords[MAX_RANK] = {};
Nd4jLong* ptr_coords = (Nd4jLong*)&coords;
//nd4j_printf("generic s: %i e: %i \n", (int)start,(int)stop);
auto len = (stop-start);
// its extended as {rank+1} so extendedStrides[rank] is valid
auto innermostStride = extendedStrides[rank];
sd::index2coords_C(start, rank, outShapes, ptr_coords);
//the packed last dimension is not part of coords, so the same rank-sized coords address both input and output
auto offset = sd::offset_from_coords(extendedStrides, outStrides, ptr_coords, rank);
for(auto k=0; k < len; k++){
auto buffPart = &(buff[offset.first]);
auto outBuffPart = &(outBuff[offset.second]);
*outBuffPart = pack<X>(buffPart, innermostStride, threshold);
offset = inc_coords(outShapes, extendedStrides, outStrides, ptr_coords, offset, rank);
}
};
samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
}
}
}
/////////////////////////////////////////////////////////////
void compareAndBitpack(sd::graph::Context& block, const NDArray& input, const NDArray& threshold, NDArray& output) {
BUILD_SINGLE_SELECTOR(input.dataType(), compareAndBitpack_, (input, threshold, output), LIBND4J_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void compareAndBitpack_, (const NDArray& input, const NDArray& threshold, NDArray& output), LIBND4J_TYPES);
}
}
}

View File

@@ -0,0 +1,191 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
//
// @author AbdelRauf
//
#include <system/op_boilerplate.h>
#include <ops/declarable/helpers/imagesHelpers.h>
#include <helpers/ConstantTadHelper.h>
#include <ops/declarable/helpers/adjust_hue.h>
#include <helpers/PointersManager.h>
#include <ops/declarable/helpers/transforms.h>
#include <helpers/LoopsCoordsHelper.h>
namespace sd {
namespace ops {
namespace helpers {
template<typename X>
_CUDA_HD uint8_t pack(const X* buff, const X& threshold){
uint8_t res;
res = (buff[0] > threshold) << 7;
res = res | ((buff[1] > threshold) << 6);
res = res | ((buff[2] > threshold) << 5);
res = res | ((buff[3] > threshold) << 4);
res = res | ((buff[4] > threshold) << 3);
res = res | ((buff[5] > threshold) << 2);
res = res | ((buff[6] > threshold) << 1);
res = res | (buff[7] > threshold);
return res;
}
template<>
_CUDA_HD uint8_t pack<bool>(const bool* buff, const bool &threshold){
//ignore threshold
uint8_t res;
res = buff[0] << 7;
res = res | (buff[1] << 6);
res = res | (buff[2] << 5);
res = res | (buff[3] << 4);
res = res | (buff[4] << 3);
res = res | (buff[5] << 2);
res = res | (buff[6] << 1);
res = res | buff[7] ;
return res;
}
template<typename X>
_CUDA_HD uint8_t pack(const X* buff, int stride, const X& threshold){
uint8_t res;
res = (buff[0] > threshold) << 7;
res = res | ((buff[1*stride] > threshold) << 6);
res = res | ((buff[2*stride] > threshold) << 5);
res = res | ((buff[3*stride] > threshold) << 4);
res = res | ((buff[4*stride] > threshold) << 3);
res = res | ((buff[5*stride] > threshold) << 2);
res = res | ((buff[6*stride] > threshold) << 1);
res = res | (buff[7*stride] > threshold);
return res;
}
template<>
_CUDA_HD uint8_t pack<bool>(const bool* buff, int stride, const bool &threshold){
//ignore threshold
uint8_t res;
res = buff[0] << 7;
res = res | (buff[1*stride] << 6);
res = res | (buff[2*stride] << 5);
res = res | (buff[3*stride] << 4);
res = res | (buff[4*stride] << 3);
res = res | (buff[5*stride] << 2);
res = res | (buff[6*stride] << 1);
res = res | buff[7*stride] ;
return res;
}
///////////////////////////////////////////////////////////////////
template <typename T>
static void _CUDA_G cmpBitpack(const void* vx, void* vz, int rank, int len, const Nd4jLong *xStridesExtended, const Nd4jLong *outPutShapeInfo, T threshold) {
const T* x = reinterpret_cast<const T*>(vx);
uint8_t* z = reinterpret_cast<uint8_t*>(vz);
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
auto shapes = shape::shapeOf(outPutShapeInfo);
auto zStrides = shape::stride(outPutShapeInfo);
Nd4jLong coords[MAX_RANK] = {};
Nd4jLong* ptr_coords = (Nd4jLong*)&coords;
// its extended as {rank+1} so xStridesExtended[rank] is valid
auto inLastStride = xStridesExtended[rank];
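//grid-stride loop: each thread packs every (gridDim.x * blockDim.x)-th output byte, so any launch configuration covers the whole output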
for(auto k=tid; k < len; k+=gridDim.x * blockDim.x){
sd::index2coords_C(k, rank, shapes, ptr_coords);
auto offset = sd::offset_from_coords(xStridesExtended, zStrides, ptr_coords, rank);
auto buffPart = &(x[offset.first]);
auto outBuffPart = &(z[offset.second]);
*outBuffPart = pack<T>(buffPart, inLastStride, threshold);
}
}
template <typename T>
static void _CUDA_G cmpBitpackEws(const void* vx, void* vz, int len, const Nd4jLong xStride, const Nd4jLong yStride, T threshold) {
const T* x = reinterpret_cast<const T*>(vx);
uint8_t* z = reinterpret_cast<uint8_t*>(vz);
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
if(xStride==1){
for(auto k=tid; k < len; k+=gridDim.x * blockDim.x){
auto buffPart = &(x[k*8]);
auto outBuffPart = &(z[k*yStride]);
*outBuffPart = pack<T>(buffPart, threshold);
}
}else{
for(auto k=tid; k < len; k+=gridDim.x * blockDim.x){
auto buffPart = &(x[k*8*xStride]);
auto outBuffPart = &(z[k*yStride]);
*outBuffPart = pack<T>(buffPart, xStride, threshold);
}
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
static _CUDA_H void cmpBitpackCudaLauncher(sd::graph::Context& block, const NDArray& input, const NDArray& thresholdScalar, NDArray& output) {
T threshold = thresholdScalar.e<T>(0);
auto inStrides = input.stridesOf();
auto rank = output.rankOf();
//threadblock size
const int threadsPerBlock = MAX_NUM_THREADS / 2;
//grid size
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
auto stream = block.launchContext()->getCudaStream();
//nd4j_printf("n %i g %i th %i \n", output.lengthOf(), blocksPerGrid, threadsPerBlock);
PointersManager manager(block.launchContext(), "compare_and_bitpack");
NDArray::prepareSpecialUse({&output}, {&input});
if(input.ews()>0 && output.ews()>0 && input.ordering()=='c' && output.ordering()=='c'){
cmpBitpackEws<T><<<blocksPerGrid, threadsPerBlock, 0, *stream>>>(input.specialBuffer(), output.specialBuffer(), output.lengthOf(), inStrides[rank-1], output.stridesOf()[rank-1], threshold);
}else{
//if the output shape is {n1, n2, n3} then the input shape is {n1, n2, n3 * 8}
//therefore we can treat the input as shape {n1, n2, n3, 8} and correct its strides;
//as we do not need the last dimension's shape info, we just extend the strides array and correct its entries
Nd4jLong extendedStrides[MAX_RANK];
for(int i=0;i<rank; i++){
extendedStrides[i] = inStrides[i];
}
//lets correct new stride
extendedStrides[rank-1] = 8*inStrides[rank-1];
extendedStrides[rank] = inStrides[rank-1];
auto strideSize = (rank+1)*sizeof(Nd4jLong);
Nd4jLong* extendedStridesDevPtr = reinterpret_cast<Nd4jLong*>(manager.replicatePointer(extendedStrides, strideSize));
cmpBitpack<T><<<blocksPerGrid, threadsPerBlock, 0, *stream>>>(input.specialBuffer(), output.specialBuffer(), rank, output.lengthOf(), extendedStridesDevPtr, output.specialShapeInfo(), threshold);
}
NDArray::registerSpecialUse({&output}, {&input});
manager.synchronize();
}
void compareAndBitpack(sd::graph::Context& block, const NDArray& input, const NDArray& threshold, NDArray& output) {
BUILD_SINGLE_SELECTOR(input.dataType(), cmpBitpackCudaLauncher, (block, input, threshold, output), LIBND4J_TYPES);
}
}
}
}

View File

@@ -26,7 +26,7 @@
#include <ops/declarable/helpers/helpers.h>
#include <helpers/helper_random.h>
#include <graph/RandomGenerator.h>
+#include <graph/Context.h>
namespace sd {
namespace ops {
namespace helpers {
@@ -84,6 +84,8 @@ namespace helpers {
void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps);
void split(sd::LaunchContext* context, const NDArray& input, std::vector<NDArray*>& outArrs, const int axis);
+void compareAndBitpack(graph::Context& block, const NDArray& input, const NDArray& threshold, NDArray& output);
}
}
}

View File

@@ -27,7 +27,7 @@
#include <ops/ops.h>
#include <helpers/GradCheck.h>
#include <loops/random.h>
+#include <array/DataType.h>
using namespace sd;
@@ -1757,6 +1757,208 @@ TEST_F(DeclarableOpsTests9, prelu_test14) {
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) {
auto x = NDArrayFactory::create<float>('c', {2, 3, 16}, {
0.865595f, 0.381197f, 0.911656f, 0.256752f, 0.084921f, 0.070434f, 0.469923f, 0.269935f, 0.510656f, 0.949777f, 0.926772f, 0.622540f, 0.688253f, 0.164974f,
0.068558f, 0.031173f, 0.910035f, 0.219362f, 0.731336f, 0.135392f, 0.449875f, 0.020135f, 0.891820f, 0.907567f, 0.114376f, 0.652253f, 0.892939f, 0.698095f,
0.423831f, 0.971155f, 0.968733f, 0.194465f, 0.852475f, 0.642962f, 0.417665f, 0.768379f, 0.753035f, 0.738440f, 0.046251f, 0.659487f, 0.486230f, 0.246724f,
0.276700f, 0.103631f, 0.843105f, 0.562587f, 0.784459f, 0.109871f, 0.455828f, 0.129641f, 0.002471f, 0.148281f, 0.976162f, 0.603573f, 0.752530f, 0.249840f,
0.723716f, 0.658430f, 0.661057f, 0.328042f, 0.338351f, 0.903157f, 0.485580f, 0.405103f, 0.335052f, 0.509858f, 0.764852f, 0.764527f, 0.382572f, 0.962121f,
0.296145f, 0.602766f, 0.169683f, 0.750371f, 0.993936f, 0.914704f, 0.199342f, 0.858098f, 0.617198f, 0.219334f, 0.167574f, 0.305204f, 0.960773f, 0.537944f,
0.245441f, 0.787276f, 0.968920f, 0.980918f, 0.615237f, 0.355165f, 0.480441f, 0.304282f, 0.961229f, 0.639195f, 0.017776f, 0.836153f
});
auto threshold = NDArrayFactory::create<float>(0.5f);
auto exp = NDArrayFactory::create<uint8_t>('c', {2, 3, 2}, {160, 248, 163, 118, 221, 14, 14, 228, 117, 118, 55, 141});
sd::ops::compare_and_bitpack op;
auto result = op.evaluate({&x, &threshold}, {}, {}, {});
auto output = result.at(0);
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
}
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test2) {
auto x = NDArrayFactory::create<bool>('c', {2, 3, 16}, {
true, false, true, false, false, false, false, false, true,
true, true, true, true, false, false, false, true, false,
true, false, false, false, true, true, false, true, true,
true, false, true, true, false, true, true, false, true,
true, true, false, true, false, false, false, false, true,
true, true, false, false, false, false, false, true, true,
true, false, true, true, true, false, false, true, false,
false, false, true, true, true, false, true, false, true,
false, true, true, true, false, true, true, false, false,
false, true, true, false, true, true, true, true, false,
false, false, true, true, false, true
});
//the threshold is actually ignored here for bool input
auto threshold = NDArrayFactory::create<bool>(true);
auto exp = NDArrayFactory::create<uint8_t>('c', {2, 3, 2}, {160, 248, 163, 118, 221, 14, 14, 228, 117, 118, 55, 141});
sd::ops::compare_and_bitpack op;
auto result = op.evaluate({&x, &threshold}, {}, {}, {});
auto output = result.at(0);
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test3) {
auto x = NDArrayFactory::create<float>('c', {2, 0, 3, 16});
auto threshold = NDArrayFactory::create<float>(0.5f);
auto exp = NDArrayFactory::create<uint8_t>('c', {2, 0, 3, 2});
sd::ops::compare_and_bitpack op;
auto result = op.evaluate({&x, &threshold}, {}, {}, {});
auto output = result.at(0);
ASSERT_EQ(ND4J_STATUS_OK, result.status());
output->printShapeInfo("output");
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test4) {
auto x = NDArrayFactory::create<float>('c', {2, 0, 3, 13});
auto threshold = NDArrayFactory::create<float>(0.5f);
sd::ops::compare_and_bitpack op;
ASSERT_THROW(op.evaluate({&x, &threshold}, {}, {}, {}), std::invalid_argument);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test5) {
auto x = NDArrayFactory::create<float>('c', {2, 0, 3, 13});
auto threshold = NDArrayFactory::create<float>(0.5f);
auto out = NDArrayFactory::create<uint8_t>('c', {2, 0, 3, 1});
sd::ops::compare_and_bitpack op;
ASSERT_THROW(op.execute({&x, &threshold}, {&out}, {}, {}), std::invalid_argument);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test6) {
auto x = NDArrayFactory::create<float>('c', {2, 0, 3, 8});
auto threshold = NDArrayFactory::create<float>(0.5f);
auto out = NDArrayFactory::create<uint8_t>('c', {2, 0, 3, 2});
sd::ops::compare_and_bitpack op;
//shape mismatch throws runtime error
ASSERT_THROW(op.execute({&x, &threshold}, {&out}, {}, {}), std::runtime_error);
}
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test7) {
constexpr int pp = 32*32*16;
constexpr int s1 = 3;
constexpr int t1 = 8;
std::vector<Nd4jLong> shape1 = {pp};
std::vector<Nd4jLong> strides1 = {s1};
std::vector<Nd4jLong> shape2 = {pp/8};
std::vector<Nd4jLong> strides2 = {t1};
ShapeDescriptor desc1 (DataType::BOOL, 'c', shape1, strides1, s1);
ShapeDescriptor desc2 (DataType::UINT8, 'c', shape2, strides2, t1);
auto x = NDArrayFactory::create(desc1);
auto output = NDArrayFactory::create(desc2);
auto exp = NDArrayFactory::create(desc2);
auto threshold = NDArrayFactory::create<bool>(true);
auto buff = x.bufferAsT<bool>();
uint8_t *expBuff = exp.bufferAsT<uint8_t>();
//generate test
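//draw a random byte per group of 8, scatter its bits (MSB first) into the strided bool input;
//that byte is exactly the packed value compare_and_bitpack should produce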
for(int l=0;l<pp; l+=8){
uint8_t test = rand() % 255;
expBuff[l/8*t1] = test;
auto buffP = &(buff[l*s1]);
buffP[0] = test & (1<<7);
buffP[1*s1] = test & (1<<6);
buffP[2*s1] = test & (1<<5);
buffP[3*s1] = test & (1<<4);
buffP[4*s1] = test & (1<<3);
buffP[5*s1] = test & (1<<2);
buffP[6*s1] = test & (1<<1);
buffP[7*s1] = test & 1;
}
//explicit sync to device
x.tickWriteHost();
exp.tickWriteHost();
x.syncToDevice();
exp.syncToDevice();
sd::ops::compare_and_bitpack op;
auto result = op.execute({&x, &threshold}, {&output}, {}, {});
ASSERT_EQ(Status::OK(), result);
ASSERT_TRUE(exp.isSameShape(&output));
ASSERT_TRUE(exp.equalsTo(&output));
}
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test8) {
constexpr int pp = 32;
constexpr int s1 = 2;
constexpr int s2 = (s1*pp) + 3;
constexpr int s3 = (s2*pp) + 4;
constexpr int t1 = 2;
constexpr int t2 = (t1*pp/8) + 3;
constexpr int t3 = (t2*pp) + 4;
std::vector<Nd4jLong> shape1 = {pp,pp,pp};
std::vector<Nd4jLong> strides1 = {s3 , s2 , s1};
std::vector<Nd4jLong> shape2 = {pp,pp,pp/8};
std::vector<Nd4jLong> strides2 = {t3 , t2 , t1};
ShapeDescriptor desc1 (DataType::BOOL, 'c', shape1, strides1, 0);
ShapeDescriptor desc2 (DataType::UINT8, 'c', shape2, strides2, 0);
auto x = NDArrayFactory::create(desc1);
auto output = NDArrayFactory::create(desc2);
auto exp = NDArrayFactory::create(desc2);
auto threshold = NDArrayFactory::create<bool>(true);
auto buff = x.bufferAsT<bool>();
uint8_t *expBuff = exp.bufferAsT<uint8_t>();
//generate test
for(int i=0;i<pp;i++){
for(int j=0;j<pp;j++){
for(int l=0;l<pp; l+=8){
uint8_t test = rand() % 255;
expBuff[l/8*t1 + j*t2 + i *t3] = test;
auto buffP = &(buff[j*s2 + i *s3 + l*s1]);
buffP[0] = test & (1<<7);
buffP[1*s1] = test & (1<<6);
buffP[2*s1] = test & (1<<5);
buffP[3*s1] = test & (1<<4);
buffP[4*s1] = test & (1<<3);
buffP[5*s1] = test & (1<<2);
buffP[6*s1] = test & (1<<1);
buffP[7*s1] = test & 1;
}
}
}
//explicit sync to device
x.tickWriteHost();
exp.tickWriteHost();
x.syncToDevice();
exp.syncToDevice();
sd::ops::compare_and_bitpack op;
auto result = op.execute({&x, &threshold}, {&output}, {}, {});
ASSERT_EQ(Status::OK(), result);
ASSERT_TRUE(exp.isSameShape(&output));
ASSERT_TRUE(exp.equalsTo(&output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) {
@@ -1776,25 +1978,6 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) {
}
-////////////////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) {
-auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
-auto threshold = NDArrayFactory::create<double>(2.0);
-auto exp = NDArrayFactory::create<uint8_t>('c', {2, 3, 4}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
-0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1});
-sd::ops::compare_and_bitpack op;
-auto result = op.evaluate({&x, &threshold}, {}, {}, {});
-ASSERT_EQ(ND4J_STATUS_OK, result.status());
-auto output = result.at(0);
-// output->printIndexedBuffer("Packed to uint8");
-ASSERT_TRUE(exp.isSameShape(output));
-ASSERT_TRUE(exp.equalsTo(output));
-}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {