Compare_and_bitpack: It was reimplemented. now the last dimension should be divisible by 8
Signed-off-by: AbdelRauf <rauf@konduit.ai>master
parent
b66454d593
commit
fe22bd5726
|
@ -23,6 +23,7 @@
|
||||||
#include <ops/declarable/generic/helpers/BroadcastHelper.h>
|
#include <ops/declarable/generic/helpers/BroadcastHelper.h>
|
||||||
#include <ops/declarable/headers/parity_ops.h>
|
#include <ops/declarable/headers/parity_ops.h>
|
||||||
#include <ops/declarable/headers/datatypes.h>
|
#include <ops/declarable/headers/datatypes.h>
|
||||||
|
#include <ops/declarable/helpers/transforms.h>
|
||||||
#include <array/NDArrayFactory.h>
|
#include <array/NDArrayFactory.h>
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
|
@ -31,17 +32,9 @@ namespace sd {
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto x = INPUT_VARIABLE(0);
|
||||||
auto y = INPUT_VARIABLE(1);
|
auto y = INPUT_VARIABLE(1);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
auto z0 = NDArrayFactory::create<bool>(x->ordering(), x->getShapeAsVector(), block.launchContext());
|
|
||||||
BROADCAST_CHECK_EMPTY(x, y, (&z0));
|
|
||||||
|
|
||||||
auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0);
|
sd::ops::helpers::compareAndBitpack(block, *x, *y, *z);
|
||||||
bitcast res;
|
return Status::OK();
|
||||||
auto status = res.execute({ tZ }, { z }, {}, { DataType::UINT8 }, {}, {}, false);
|
|
||||||
if (tZ != &z0) {
|
|
||||||
delete tZ;
|
|
||||||
}
|
|
||||||
|
|
||||||
return status;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(compare_and_bitpack) {
|
DECLARE_TYPES(compare_and_bitpack) {
|
||||||
|
@ -53,9 +46,15 @@ namespace sd {
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(compare_and_bitpack) {
|
DECLARE_SHAPE_FN(compare_and_bitpack) {
|
||||||
auto inShape = inputShape->at(0);
|
auto inShape = inputShape->at(0);
|
||||||
|
auto shapes = shape::shapeOf(inShape);
|
||||||
|
const int rank = shape::rank(inShape);
|
||||||
|
REQUIRE_TRUE(!shape::isScalar(inShape), 0, "Input should not be a scalar");
|
||||||
|
std::vector<Nd4jLong> shapeDims {shapes, shapes + rank};
|
||||||
|
REQUIRE_TRUE(shapeDims[rank-1] % 8 ==0 , 0, "Last dimension of the input (which is %i) should be divisible by 8 ", shapeDims[rank-1]);
|
||||||
|
shapeDims[rank-1] = shapeDims[rank-1] / 8 ;
|
||||||
DataType newType = DataType::UINT8;
|
DataType newType = DataType::UINT8;
|
||||||
|
auto outputShape = ConstantShapeHelper::getInstance().createShapeInfo(newType, shape::order(inShape), shapeDims);
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(inShape, newType)));
|
return SHAPELIST(outputShape);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1908,15 +1908,15 @@ namespace sd {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* compare_and_bitpack - compare with greater and pack result with uint8
|
* compare_and_bitpack - Compare values of input to threshold and pack resulting bits into a uint8
|
||||||
*
|
*
|
||||||
* input params:
|
* input params:
|
||||||
* 0 - NDArray (input)
|
* 0 - NDArray (input). Note: last dimension should be divisibly by 8
|
||||||
* 1 - 0D Tensor - threshold
|
* 1 - 0D Tensor - threshold to compare against. Note: when input and threshold is bool type, the threshold is ignored
|
||||||
*
|
*
|
||||||
*
|
*
|
||||||
* output:
|
* output:
|
||||||
* 0 - NDArray with the same shape as input and type uint8
|
* 0 - NDArray with the shape as {input.dim0,...input.dimLast/8} and type uint8
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_compare_and_bitpack)
|
#if NOT_EXCLUDED(OP_compare_and_bitpack)
|
||||||
DECLARE_CUSTOM_OP(compare_and_bitpack, 2, 1, false, 0, 0);
|
DECLARE_CUSTOM_OP(compare_and_bitpack, 2, 1, false, 0, 0);
|
||||||
|
|
|
@ -0,0 +1,191 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author AbdelRauf
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <type_traits>
|
||||||
|
#include <cmath>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <memory>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <execution/ThreadPool.h>
|
||||||
|
#include <helpers/LoopsCoordsHelper.h>
|
||||||
|
#include <ops/declarable/helpers/transforms.h>
|
||||||
|
#include <helpers/LoopsCoordsHelper.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
|
||||||
|
template<typename X>
|
||||||
|
uint8_t pack(const X* buff, const X& threshold){
|
||||||
|
uint8_t res;
|
||||||
|
res = (buff[0] > threshold) << 7;
|
||||||
|
res = res | ((buff[1] > threshold) << 6);
|
||||||
|
res = res | ((buff[2] > threshold) << 5);
|
||||||
|
res = res | ((buff[3] > threshold) << 4);
|
||||||
|
res = res | ((buff[4] > threshold) << 3);
|
||||||
|
res = res | ((buff[5] > threshold) << 2);
|
||||||
|
res = res | ((buff[6] > threshold) << 1);
|
||||||
|
res = res | (buff[7] > threshold);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<>
|
||||||
|
uint8_t pack<bool>(const bool* buff, const bool &threshold){
|
||||||
|
//ignore threshold
|
||||||
|
uint8_t res;
|
||||||
|
res = buff[0] << 7;
|
||||||
|
res = res | (buff[1] << 6);
|
||||||
|
res = res | (buff[2] << 5);
|
||||||
|
res = res | (buff[3] << 4);
|
||||||
|
res = res | (buff[4] << 3);
|
||||||
|
res = res | (buff[5] << 2);
|
||||||
|
res = res | (buff[6] << 1);
|
||||||
|
res = res | buff[7] ;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X>
|
||||||
|
uint8_t pack(const X* buff, int stride, const X& threshold){
|
||||||
|
uint8_t res;
|
||||||
|
res = (buff[0] > threshold) << 7;
|
||||||
|
res = res | ((buff[1*stride] > threshold) << 6);
|
||||||
|
res = res | ((buff[2*stride] > threshold) << 5);
|
||||||
|
res = res | ((buff[3*stride] > threshold) << 4);
|
||||||
|
res = res | ((buff[4*stride] > threshold) << 3);
|
||||||
|
res = res | ((buff[5*stride] > threshold) << 2);
|
||||||
|
res = res | ((buff[6*stride] > threshold) << 1);
|
||||||
|
res = res | (buff[7*stride] > threshold);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<>
|
||||||
|
uint8_t pack<bool>(const bool* buff, int stride, const bool &threshold){
|
||||||
|
//ignore threshold
|
||||||
|
uint8_t res;
|
||||||
|
res = buff[0] << 7;
|
||||||
|
res = res | (buff[1*stride] << 6);
|
||||||
|
res = res | (buff[2*stride] << 5);
|
||||||
|
res = res | (buff[3*stride] << 4);
|
||||||
|
res = res | (buff[4*stride] << 3);
|
||||||
|
res = res | (buff[5*stride] << 2);
|
||||||
|
res = res | (buff[6*stride] << 1);
|
||||||
|
res = res | buff[7*stride] ;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <typename X>
|
||||||
|
void compareAndBitpack_(const NDArray& input, const NDArray& thresholdScalar, NDArray& output) {
|
||||||
|
|
||||||
|
auto rank =input.rankOf();
|
||||||
|
X threshold = thresholdScalar.e<X>(0);
|
||||||
|
auto buff = input.bufferAsT<X>();
|
||||||
|
uint8_t *outBuff = output.bufferAsT<uint8_t>();
|
||||||
|
if(input.ordering()=='c' && output.ordering()=='c' && input.ews()==1 && output.ews()==1){
|
||||||
|
FUNC_1D func = [buff, outBuff, threshold](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
|
||||||
|
//nd4j_printf("s: %i e: %i \n", (int)start,(int)stop);
|
||||||
|
auto outBuffPart = outBuff + start;
|
||||||
|
auto buffPart = buff + start*8;
|
||||||
|
auto len = stop-start;
|
||||||
|
//run
|
||||||
|
for(auto i=0;i < len; i++){
|
||||||
|
outBuffPart[i] = pack<X>(&(buffPart[8*i]), threshold);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
|
||||||
|
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
|
||||||
|
auto inShapes = input.shapeOf();
|
||||||
|
auto outShapes = output.shapeOf();
|
||||||
|
auto inStrides = input.stridesOf();
|
||||||
|
auto outStrides = output.stridesOf();
|
||||||
|
|
||||||
|
if(rank == 1){
|
||||||
|
auto inLastStride = inStrides[rank-1];
|
||||||
|
auto outLastStride = outStrides[rank-1];
|
||||||
|
FUNC_1D func = [buff, outBuff, inLastStride, outLastStride, threshold](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
|
||||||
|
//nd4j_printf("rankkk s: %i e: %i \n", (int)start,(int)stop);
|
||||||
|
auto buffPart = buff + start*8*inLastStride;
|
||||||
|
auto outBuffPart = outBuff + start* outLastStride;
|
||||||
|
auto len = stop-start;
|
||||||
|
//run
|
||||||
|
for(auto i=0;i < len; i++){
|
||||||
|
*outBuffPart = pack<X>(buffPart, inLastStride, threshold);
|
||||||
|
buffPart += 8*inLastStride;
|
||||||
|
outBuffPart += outLastStride;
|
||||||
|
|
||||||
|
}
|
||||||
|
};
|
||||||
|
samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
|
||||||
|
}else{
|
||||||
|
//if output shape is {n1, n2, n3} then input shape is { n1. n2, n3 * 8}
|
||||||
|
//therefore we can split input shape {n1, n2, n3 , 8} and correct its stride
|
||||||
|
//as we do not need last shape info. lets just extend and correct its stride
|
||||||
|
Nd4jLong extendedStrides[MAX_RANK];
|
||||||
|
for(int i=0;i<rank; i++){
|
||||||
|
extendedStrides[i] = inStrides[i];
|
||||||
|
}
|
||||||
|
//lets correct new stride
|
||||||
|
extendedStrides[rank-1] = 8*inStrides[rank-1];
|
||||||
|
extendedStrides[rank] = inStrides[rank-1];
|
||||||
|
//general case. its slow. we can improve it for special case later
|
||||||
|
//generic case that could be further imrpoved. for now its slow
|
||||||
|
FUNC_1D func = [rank, buff, outBuff, outShapes, extendedStrides, outStrides, threshold]
|
||||||
|
(uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
|
||||||
|
Nd4jLong coords[MAX_RANK] = {};
|
||||||
|
Nd4jLong* ptr_coords = (Nd4jLong*)&coords;
|
||||||
|
//nd4j_printf("generic s: %i e: %i \n", (int)start,(int)stop);
|
||||||
|
auto len = (stop-start);
|
||||||
|
// its extended as {rank+1} so extendedStrides[rank] is valid
|
||||||
|
auto innermostStride = extendedStrides[rank];
|
||||||
|
sd::index2coords_C(start, rank, outShapes, ptr_coords);
|
||||||
|
//here last dimension will not be in coords. this way output shape and input shapes are equal
|
||||||
|
auto offset = sd::offset_from_coords(extendedStrides, outStrides, ptr_coords, rank);
|
||||||
|
for(auto k=0; k < len; k++){
|
||||||
|
auto buffPart = &(buff[offset.first]);
|
||||||
|
auto outBuffPart = &(outBuff[offset.second]);
|
||||||
|
*outBuffPart = pack<X>(buffPart, innermostStride, threshold);
|
||||||
|
offset = inc_coords(outShapes, extendedStrides, outStrides, ptr_coords, offset, rank);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////
|
||||||
|
void compareAndBitpack(sd::graph::Context& block, const NDArray& input, const NDArray& threshold, NDArray& output) {
|
||||||
|
|
||||||
|
BUILD_SINGLE_SELECTOR(input.dataType(), compareAndBitpack_, (input, threshold, output), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void compareAndBitpack_, (const NDArray& input, const NDArray& threshold, NDArray& output), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,191 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author AbdelRauf
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <system/op_boilerplate.h>
|
||||||
|
#include <ops/declarable/helpers/imagesHelpers.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <ops/declarable/helpers/adjust_hue.h>
|
||||||
|
#include <helpers/PointersManager.h>
|
||||||
|
#include <ops/declarable/helpers/transforms.h>
|
||||||
|
#include <helpers/LoopsCoordsHelper.h>
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
template<typename X>
|
||||||
|
_CUDA_HD uint8_t pack(const X* buff, const X& threshold){
|
||||||
|
uint8_t res;
|
||||||
|
res = (buff[0] > threshold) << 7;
|
||||||
|
res = res | ((buff[1] > threshold) << 6);
|
||||||
|
res = res | ((buff[2] > threshold) << 5);
|
||||||
|
res = res | ((buff[3] > threshold) << 4);
|
||||||
|
res = res | ((buff[4] > threshold) << 3);
|
||||||
|
res = res | ((buff[5] > threshold) << 2);
|
||||||
|
res = res | ((buff[6] > threshold) << 1);
|
||||||
|
res = res | (buff[7] > threshold);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<>
|
||||||
|
_CUDA_HD uint8_t pack<bool>(const bool* buff, const bool &threshold){
|
||||||
|
//ignore threshold
|
||||||
|
uint8_t res;
|
||||||
|
res = buff[0] << 7;
|
||||||
|
res = res | (buff[1] << 6);
|
||||||
|
res = res | (buff[2] << 5);
|
||||||
|
res = res | (buff[3] << 4);
|
||||||
|
res = res | (buff[4] << 3);
|
||||||
|
res = res | (buff[5] << 2);
|
||||||
|
res = res | (buff[6] << 1);
|
||||||
|
res = res | buff[7] ;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X>
|
||||||
|
_CUDA_HD uint8_t pack(const X* buff, int stride, const X& threshold){
|
||||||
|
uint8_t res;
|
||||||
|
res = (buff[0] > threshold) << 7;
|
||||||
|
res = res | ((buff[1*stride] > threshold) << 6);
|
||||||
|
res = res | ((buff[2*stride] > threshold) << 5);
|
||||||
|
res = res | ((buff[3*stride] > threshold) << 4);
|
||||||
|
res = res | ((buff[4*stride] > threshold) << 3);
|
||||||
|
res = res | ((buff[5*stride] > threshold) << 2);
|
||||||
|
res = res | ((buff[6*stride] > threshold) << 1);
|
||||||
|
res = res | (buff[7*stride] > threshold);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<>
|
||||||
|
_CUDA_HD uint8_t pack<bool>(const bool* buff, int stride, const bool &threshold){
|
||||||
|
//ignore threshold
|
||||||
|
uint8_t res;
|
||||||
|
res = buff[0] << 7;
|
||||||
|
res = res | (buff[1*stride] << 6);
|
||||||
|
res = res | (buff[2*stride] << 5);
|
||||||
|
res = res | (buff[3*stride] << 4);
|
||||||
|
res = res | (buff[4*stride] << 3);
|
||||||
|
res = res | (buff[5*stride] << 2);
|
||||||
|
res = res | (buff[6*stride] << 1);
|
||||||
|
res = res | buff[7*stride] ;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static void _CUDA_G cmpBitpack(const void* vx, void* vz, int rank, int len, const Nd4jLong *xStridesExtended, const Nd4jLong *outPutShapeInfo, T threshold) {
|
||||||
|
|
||||||
|
const T* x = reinterpret_cast<const T*>(vx);
|
||||||
|
uint8_t* z = reinterpret_cast<uint8_t*>(vz);
|
||||||
|
|
||||||
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto shapes = shape::shapeOf(outPutShapeInfo);
|
||||||
|
auto zStrides = shape::stride(outPutShapeInfo);
|
||||||
|
Nd4jLong coords[MAX_RANK] = {};
|
||||||
|
Nd4jLong* ptr_coords = (Nd4jLong*)&coords;
|
||||||
|
// its extended as {rank+1} so xStridesExtended[rank] is valid
|
||||||
|
auto inLastStride = xStridesExtended[rank];
|
||||||
|
|
||||||
|
for(auto k=tid; k < len; k+=gridDim.x * blockDim.x){
|
||||||
|
sd::index2coords_C(k, rank, shapes, ptr_coords);
|
||||||
|
auto offset = sd::offset_from_coords(xStridesExtended, zStrides, ptr_coords, rank);
|
||||||
|
auto buffPart = &(x[offset.first]);
|
||||||
|
auto outBuffPart = &(z[offset.second]);
|
||||||
|
*outBuffPart = pack<T>(buffPart, inLastStride, threshold);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void _CUDA_G cmpBitpackEws(const void* vx, void* vz, int len, const Nd4jLong xStride, const Nd4jLong yStride, T threshold) {
|
||||||
|
|
||||||
|
const T* x = reinterpret_cast<const T*>(vx);
|
||||||
|
uint8_t* z = reinterpret_cast<uint8_t*>(vz);
|
||||||
|
|
||||||
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
if(xStride==1){
|
||||||
|
for(auto k=tid; k < len; k+=gridDim.x * blockDim.x){
|
||||||
|
auto buffPart = &(x[k*8]);
|
||||||
|
auto outBuffPart = &(z[k*yStride]);
|
||||||
|
*outBuffPart = pack<T>(buffPart, threshold);
|
||||||
|
}
|
||||||
|
}else{
|
||||||
|
for(auto k=tid; k < len; k+=gridDim.x * blockDim.x){
|
||||||
|
auto buffPart = &(x[k*8*xStride]);
|
||||||
|
auto outBuffPart = &(z[k*yStride]);
|
||||||
|
*outBuffPart = pack<T>(buffPart, xStride, threshold);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
static _CUDA_H void cmpBitpackCudaLauncher(sd::graph::Context& block, const NDArray& input, const NDArray& thresholdScalar, NDArray& output) {
|
||||||
|
T threshold = thresholdScalar.e<T>(0);
|
||||||
|
|
||||||
|
|
||||||
|
auto inStrides = input.stridesOf();
|
||||||
|
auto rank = output.rankOf();
|
||||||
|
|
||||||
|
//threadblock size
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||||
|
//grid size
|
||||||
|
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
auto stream = block.launchContext()->getCudaStream();
|
||||||
|
//nd4j_printf("n %i g %i th %i \n", output.lengthOf(), blocksPerGrid, threadsPerBlock);
|
||||||
|
PointersManager manager(block.launchContext(), "compare_and_bitpack");
|
||||||
|
NDArray::prepareSpecialUse({&output}, {&input});
|
||||||
|
if(input.ews()>0 && output.ews()>0 && input.ordering()=='c' && output.ordering()=='c'){
|
||||||
|
cmpBitpackEws<T><<<blocksPerGrid, threadsPerBlock >>>(input.specialBuffer(), output.specialBuffer(), output.lengthOf(), inStrides[rank-1], output.stridesOf()[rank-1] , threshold);
|
||||||
|
}else{
|
||||||
|
//if output shape is {n1, n2, n3} then input shape is { n1. n2, n3 * 8}
|
||||||
|
//therefore we can split input shape {n1, n2, n3 , 8} and correct its stride
|
||||||
|
//as we do not need last shape info. lets just extend and correct its stride
|
||||||
|
Nd4jLong extendedStrides[MAX_RANK];
|
||||||
|
for(int i=0;i<rank; i++){
|
||||||
|
extendedStrides[i] = inStrides[i];
|
||||||
|
}
|
||||||
|
//lets correct new stride
|
||||||
|
extendedStrides[rank-1] = 8*inStrides[rank-1];
|
||||||
|
extendedStrides[rank] = inStrides[rank-1];
|
||||||
|
|
||||||
|
auto strideSize = (rank+1)*sizeof(Nd4jLong);
|
||||||
|
Nd4jLong* extendedStridesDevPtr = reinterpret_cast<Nd4jLong*>(manager.replicatePointer(extendedStrides, strideSize));
|
||||||
|
cmpBitpack<T><<<blocksPerGrid, threadsPerBlock >>>(input.specialBuffer(), output.specialBuffer(), rank, output.lengthOf(), extendedStridesDevPtr, output.specialShapeInfo(), threshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&output}, {&input});
|
||||||
|
manager.synchronize();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void compareAndBitpack(sd::graph::Context& block, const NDArray& input, const NDArray& threshold, NDArray& output) {
|
||||||
|
|
||||||
|
BUILD_SINGLE_SELECTOR(input.dataType(), cmpBitpackCudaLauncher, (block, input, threshold, output), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -26,7 +26,7 @@
|
||||||
#include <ops/declarable/helpers/helpers.h>
|
#include <ops/declarable/helpers/helpers.h>
|
||||||
#include <helpers/helper_random.h>
|
#include <helpers/helper_random.h>
|
||||||
#include <graph/RandomGenerator.h>
|
#include <graph/RandomGenerator.h>
|
||||||
|
#include <graph/Context.h>
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
@ -84,6 +84,8 @@ namespace helpers {
|
||||||
void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps);
|
void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps);
|
||||||
|
|
||||||
void split(sd::LaunchContext* context, const NDArray& input, std::vector<NDArray*>& outArrs, const int axis);
|
void split(sd::LaunchContext* context, const NDArray& input, std::vector<NDArray*>& outArrs, const int axis);
|
||||||
|
|
||||||
|
void compareAndBitpack(graph::Context& block, const NDArray& input, const NDArray& threshold, NDArray& output);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,7 +27,7 @@
|
||||||
#include <ops/ops.h>
|
#include <ops/ops.h>
|
||||||
#include <helpers/GradCheck.h>
|
#include <helpers/GradCheck.h>
|
||||||
#include <loops/random.h>
|
#include <loops/random.h>
|
||||||
|
#include <array/DataType.h>
|
||||||
|
|
||||||
using namespace sd;
|
using namespace sd;
|
||||||
|
|
||||||
|
@ -1757,6 +1757,208 @@ TEST_F(DeclarableOpsTests9, prelu_test14) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {2, 3, 16}, {
|
||||||
|
0.865595f, 0.381197f, 0.911656f, 0.256752f, 0.084921f, 0.070434f, 0.469923f, 0.269935f, 0.510656f, 0.949777f, 0.926772f, 0.622540f, 0.688253f, 0.164974f,
|
||||||
|
0.068558f, 0.031173f, 0.910035f, 0.219362f, 0.731336f, 0.135392f, 0.449875f, 0.020135f, 0.891820f, 0.907567f, 0.114376f, 0.652253f, 0.892939f, 0.698095f,
|
||||||
|
0.423831f, 0.971155f, 0.968733f, 0.194465f, 0.852475f, 0.642962f, 0.417665f, 0.768379f, 0.753035f, 0.738440f, 0.046251f, 0.659487f, 0.486230f, 0.246724f,
|
||||||
|
0.276700f, 0.103631f, 0.843105f, 0.562587f, 0.784459f, 0.109871f, 0.455828f, 0.129641f, 0.002471f, 0.148281f, 0.976162f, 0.603573f, 0.752530f, 0.249840f,
|
||||||
|
0.723716f, 0.658430f, 0.661057f, 0.328042f, 0.338351f, 0.903157f, 0.485580f, 0.405103f, 0.335052f, 0.509858f, 0.764852f, 0.764527f, 0.382572f, 0.962121f,
|
||||||
|
0.296145f, 0.602766f, 0.169683f, 0.750371f, 0.993936f, 0.914704f, 0.199342f, 0.858098f, 0.617198f, 0.219334f, 0.167574f, 0.305204f, 0.960773f, 0.537944f,
|
||||||
|
0.245441f, 0.787276f, 0.968920f, 0.980918f, 0.615237f, 0.355165f, 0.480441f, 0.304282f, 0.961229f, 0.639195f, 0.017776f, 0.836153f
|
||||||
|
});
|
||||||
|
auto threshold = NDArrayFactory::create<float>(0.5f);
|
||||||
|
auto exp = NDArrayFactory::create<uint8_t>('c', {2, 3, 2}, {160, 248, 163, 118, 221, 14, 14, 228, 117, 118, 55, 141});
|
||||||
|
|
||||||
|
sd::ops::compare_and_bitpack op;
|
||||||
|
auto result = op.evaluate({&x, &threshold}, {}, {}, {});
|
||||||
|
auto output = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test2) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<bool>('c', {2, 3, 16}, {
|
||||||
|
true, false, true, false, false, false, false, false, true,
|
||||||
|
true, true, true, true, false, false, false, true, false,
|
||||||
|
true, false, false, false, true, true, false, true, true,
|
||||||
|
true, false, true, true, false, true, true, false, true,
|
||||||
|
true, true, false, true, false, false, false, false, true,
|
||||||
|
true, true, false, false, false, false, false, true, true,
|
||||||
|
true, false, true, true, true, false, false, true, false,
|
||||||
|
false, false, true, true, true, false, true, false, true,
|
||||||
|
false, true, true, true, false, true, true, false, false,
|
||||||
|
false, true, true, false, true, true, true, true, false,
|
||||||
|
false, false, true, true, false, true
|
||||||
|
});
|
||||||
|
//threshold is ignored here ,actually
|
||||||
|
auto threshold = NDArrayFactory::create<bool>(true);
|
||||||
|
auto exp = NDArrayFactory::create<uint8_t>('c', {2, 3, 2}, {160, 248, 163, 118, 221, 14, 14, 228, 117, 118, 55, 141});
|
||||||
|
|
||||||
|
sd::ops::compare_and_bitpack op;
|
||||||
|
auto result = op.evaluate({&x, &threshold}, {}, {}, {});
|
||||||
|
auto output = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test3) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {2, 0, 3, 16});
|
||||||
|
auto threshold = NDArrayFactory::create<float>(0.5f);
|
||||||
|
auto exp = NDArrayFactory::create<uint8_t>('c', {2, 0, 3, 2});
|
||||||
|
|
||||||
|
sd::ops::compare_and_bitpack op;
|
||||||
|
auto result = op.evaluate({&x, &threshold}, {}, {}, {});
|
||||||
|
auto output = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
output->printShapeInfo("output");
|
||||||
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test4) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {2, 0, 3, 13});
|
||||||
|
auto threshold = NDArrayFactory::create<float>(0.5f);
|
||||||
|
sd::ops::compare_and_bitpack op;
|
||||||
|
|
||||||
|
ASSERT_THROW(op.evaluate({&x, &threshold}, {}, {}, {}), std::invalid_argument);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test5) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {2, 0, 3, 13});
|
||||||
|
auto threshold = NDArrayFactory::create<float>(0.5f);
|
||||||
|
auto out = NDArrayFactory::create<uint8_t>('c', {2, 0, 3, 1});
|
||||||
|
sd::ops::compare_and_bitpack op;
|
||||||
|
|
||||||
|
ASSERT_THROW(op.execute({&x, &threshold}, {&out}, {}, {}), std::invalid_argument);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test6) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {2, 0, 3, 8});
|
||||||
|
auto threshold = NDArrayFactory::create<float>(0.5f);
|
||||||
|
auto out = NDArrayFactory::create<uint8_t>('c', {2, 0, 3, 2});
|
||||||
|
sd::ops::compare_and_bitpack op;
|
||||||
|
//shape mismatch throws runtime error
|
||||||
|
ASSERT_THROW(op.execute({&x, &threshold}, {&out}, {}, {}), std::runtime_error);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test7) {
|
||||||
|
constexpr int pp = 32*32*16;
|
||||||
|
constexpr int s1 = 3;
|
||||||
|
constexpr int t1 = 8;
|
||||||
|
std::vector<Nd4jLong> shape1 = {pp};
|
||||||
|
std::vector<Nd4jLong> strides1 = {s1};
|
||||||
|
std::vector<Nd4jLong> shape2 = {pp/8};
|
||||||
|
std::vector<Nd4jLong> strides2 = {t1};
|
||||||
|
ShapeDescriptor desc1 (DataType::BOOL, 'c', shape1, strides1, s1);
|
||||||
|
ShapeDescriptor desc2 (DataType::UINT8, 'c', shape2, strides2, t1);
|
||||||
|
auto x = NDArrayFactory::create(desc1);
|
||||||
|
auto output = NDArrayFactory::create(desc2);
|
||||||
|
auto exp = NDArrayFactory::create(desc2);
|
||||||
|
auto threshold = NDArrayFactory::create<bool>(true);
|
||||||
|
auto buff = x.bufferAsT<bool>();
|
||||||
|
uint8_t *expBuff = exp.bufferAsT<uint8_t>();
|
||||||
|
//generate test
|
||||||
|
for(int l=0;l<pp; l+=8){
|
||||||
|
uint8_t test = rand() % 255;
|
||||||
|
expBuff[l/8*t1] = test;
|
||||||
|
auto buffP = &(buff[l*s1]);
|
||||||
|
buffP[0] = test & (1<<7);
|
||||||
|
buffP[1*s1] = test & (1<<6);
|
||||||
|
buffP[2*s1] = test & (1<<5);
|
||||||
|
buffP[3*s1] = test & (1<<4);
|
||||||
|
buffP[4*s1] = test & (1<<3);
|
||||||
|
buffP[5*s1] = test & (1<<2);
|
||||||
|
buffP[6*s1] = test & (1<<1);
|
||||||
|
buffP[7*s1] = test & 1;
|
||||||
|
}
|
||||||
|
//explicit sync to device
|
||||||
|
x.tickWriteHost();
|
||||||
|
exp.tickWriteHost();
|
||||||
|
x.syncToDevice();
|
||||||
|
exp.syncToDevice();
|
||||||
|
|
||||||
|
sd::ops::compare_and_bitpack op;
|
||||||
|
auto result = op.execute({&x, &threshold}, {&output}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result);
|
||||||
|
ASSERT_TRUE(exp.isSameShape(&output));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(&output));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test8) {
|
||||||
|
constexpr int pp = 32;
|
||||||
|
constexpr int s1 = 2;
|
||||||
|
constexpr int s2 = (s1*pp) + 3;
|
||||||
|
constexpr int s3 = (s2*pp) + 4;
|
||||||
|
constexpr int t1 = 2;
|
||||||
|
constexpr int t2 = (t1*pp/8) + 3;
|
||||||
|
constexpr int t3 = (t2*pp) + 4;
|
||||||
|
std::vector<Nd4jLong> shape1 = {pp,pp,pp};
|
||||||
|
std::vector<Nd4jLong> strides1 = {s3 , s2 , s1};
|
||||||
|
std::vector<Nd4jLong> shape2 = {pp,pp,pp/8};
|
||||||
|
std::vector<Nd4jLong> strides2 = {t3 , t2 , t1};
|
||||||
|
ShapeDescriptor desc1 (DataType::BOOL, 'c', shape1, strides1, 0);
|
||||||
|
ShapeDescriptor desc2 (DataType::UINT8, 'c', shape2, strides2, 0);
|
||||||
|
auto x = NDArrayFactory::create(desc1);
|
||||||
|
auto output = NDArrayFactory::create(desc2);
|
||||||
|
auto exp = NDArrayFactory::create(desc2);
|
||||||
|
auto threshold = NDArrayFactory::create<bool>(true);
|
||||||
|
auto buff = x.bufferAsT<bool>();
|
||||||
|
uint8_t *expBuff = exp.bufferAsT<uint8_t>();
|
||||||
|
//generate test
|
||||||
|
for(int i=0;i<pp;i++){
|
||||||
|
for(int j=0;j<pp;j++){
|
||||||
|
for(int l=0;l<pp; l+=8){
|
||||||
|
uint8_t test = rand() % 255;
|
||||||
|
expBuff[l/8*t1 + j*t2 + i *t3] = test;
|
||||||
|
auto buffP = &(buff[j*s2 + i *s3 + l*s1]);
|
||||||
|
buffP[0] = test & (1<<7);
|
||||||
|
buffP[1*s1] = test & (1<<6);
|
||||||
|
buffP[2*s1] = test & (1<<5);
|
||||||
|
buffP[3*s1] = test & (1<<4);
|
||||||
|
buffP[4*s1] = test & (1<<3);
|
||||||
|
buffP[5*s1] = test & (1<<2);
|
||||||
|
buffP[6*s1] = test & (1<<1);
|
||||||
|
buffP[7*s1] = test & 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//explicit sync to device
|
||||||
|
x.tickWriteHost();
|
||||||
|
exp.tickWriteHost();
|
||||||
|
x.syncToDevice();
|
||||||
|
exp.syncToDevice();
|
||||||
|
sd::ops::compare_and_bitpack op;
|
||||||
|
auto result = op.execute({&x, &threshold}, {&output}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result);
|
||||||
|
ASSERT_TRUE(exp.isSameShape(&output));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(&output));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) {
|
TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) {
|
||||||
|
|
||||||
|
@ -1776,25 +1978,6 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) {
|
|
||||||
|
|
||||||
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
|
|
||||||
auto threshold = NDArrayFactory::create<double>(2.0);
|
|
||||||
auto exp = NDArrayFactory::create<uint8_t>('c', {2, 3, 4}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
||||||
0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1});
|
|
||||||
|
|
||||||
sd::ops::compare_and_bitpack op;
|
|
||||||
|
|
||||||
auto result = op.evaluate({&x, &threshold}, {}, {}, {});
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
|
||||||
auto output = result.at(0);
|
|
||||||
// output->printIndexedBuffer("Packed to uint8");
|
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {
|
TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue