cavis/libnd4j/include/ops/declarable/helpers/cpu/compare_and_bitpack.cpp

192 lines
9.4 KiB
C++

/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
//
// @author AbdelRauf
//
#include <type_traits>
#include <cmath>
#include <stdexcept>
#include <memory>
#include <execution/Threads.h>
#include <execution/ThreadPool.h>
#include <helpers/LoopsCoordsHelper.h>
#include <ops/declarable/helpers/transforms.h>
#include <helpers/LoopsCoordsHelper.h>
namespace sd {
namespace ops {
namespace helpers {
/**
 * Packs eight consecutive comparison results into a single byte.
 *
 * Bit (7 - k) of the result is set when buff[k] > threshold, so buff[0]
 * ends up in the most significant bit and buff[7] in the least significant.
 *
 * @param buff      pointer to at least eight readable elements
 * @param threshold value every element is compared against (strictly greater)
 * @return the packed byte
 */
template <typename X>
uint8_t pack(const X* buff, const X& threshold) {
    uint8_t packed = 0;
    for (int pos = 0; pos < 8; ++pos) {
        // comparison result is promoted to int (0/1) before being shifted into place
        packed = static_cast<uint8_t>(packed | ((buff[pos] > threshold) << (7 - pos)));
    }
    return packed;
}

/**
 * Boolean flavour: the values are already flags, so they are packed as-is
 * and the threshold argument is deliberately ignored.
 *
 * @param buff      pointer to at least eight readable bools
 * @param threshold unused
 * @return the packed byte, buff[0] in the most significant bit
 */
template <>
uint8_t pack<bool>(const bool* buff, const bool& threshold) {
    (void)threshold;  // intentionally unused
    uint8_t packed = 0;
    for (int pos = 0; pos < 8; ++pos) {
        packed = static_cast<uint8_t>(packed | (buff[pos] << (7 - pos)));
    }
    return packed;
}
/**
 * Strided variant: packs the comparison results of eight elements spaced
 * `stride` apart into a single byte.
 *
 * Bit (7 - k) of the result is set when buff[k * stride] > threshold, so the
 * first visited element ends up in the most significant bit.
 *
 * @param buff      pointer to the first element of the group
 * @param stride    distance, in elements, between two consecutive group members
 * @param threshold value every element is compared against (strictly greater)
 * @return the packed byte
 */
template <typename X>
uint8_t pack(const X* buff, int stride, const X& threshold) {
    uint8_t packed = 0;
    for (int pos = 0; pos < 8; ++pos) {
        // comparison result is promoted to int (0/1) before being shifted into place
        packed = static_cast<uint8_t>(packed | ((buff[pos * stride] > threshold) << (7 - pos)));
    }
    return packed;
}

/**
 * Boolean flavour of the strided variant: the flags are packed as-is and the
 * threshold argument is deliberately ignored.
 *
 * @param buff      pointer to the first bool of the group
 * @param stride    distance, in elements, between two consecutive group members
 * @param threshold unused
 * @return the packed byte, first visited element in the most significant bit
 */
template <>
uint8_t pack<bool>(const bool* buff, int stride, const bool& threshold) {
    (void)threshold;  // intentionally unused
    uint8_t packed = 0;
    for (int pos = 0; pos < 8; ++pos) {
        packed = static_cast<uint8_t>(packed | (buff[pos * stride] << (7 - pos)));
    }
    return packed;
}
/**
 * Compares the elements of `input` with a scalar threshold and packs every
 * group of eight results along the innermost dimension into one uint8_t of
 * `output` (first element of the group -> most significant bit, see pack()).
 * For X == bool the pack() specializations ignore the threshold.
 *
 * Three code paths are used:
 *  - both arrays report 'c' ordering and ews() == 1: the buffers are walked linearly;
 *  - rank 1 with arbitrary strides: pointers are advanced by the last-dimension strides;
 *  - everything else: generic coordinate walk over the output shape.
 *
 * The work is split over output elements through samediff::Threads::parallel_for.
 *
 * NOTE(review): the function assumes the innermost input dimension is exactly
 * 8x the innermost output dimension; nothing here validates it, so this is
 * presumably checked by the caller - confirm.
 *
 * @tparam X              element type of `input`
 * @param input           array holding the values to compare
 * @param thresholdScalar array whose element 0 is read as the threshold
 * @param output          uint8 array receiving the packed bits
 */
template <typename X>
void compareAndBitpack_(const NDArray& input, const NDArray& thresholdScalar, NDArray& output) {
    auto rank = input.rankOf();
    X threshold = thresholdScalar.e<X>(0);
    auto buff = input.bufferAsT<X>();
    uint8_t* outBuff = output.bufferAsT<uint8_t>();

    if (input.ordering() == 'c' && output.ordering() == 'c' && input.ews() == 1 && output.ews() == 1) {
        // Linear fast path: output element i consumes input elements [8*i, 8*i + 8).
        FUNC_1D func = [buff, outBuff, threshold](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
            auto outBuffPart = outBuff + start;
            auto buffPart = buff + start * 8;
            const int64_t len = stop - start;
            // 64-bit counter: an int counter (and 8*i) would overflow on very large chunks
            for (int64_t i = 0; i < len; ++i) {
                outBuffPart[i] = pack<X>(&(buffPart[8 * i]), threshold);
            }
        };
        samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
    } else {
        auto inShapes = input.shapeOf();
        auto outShapes = output.shapeOf();
        auto inStrides = input.stridesOf();
        auto outStrides = output.stridesOf();

        if (rank == 1) {
            // Strided vector: step the input by 8 strided elements per output element.
            auto inLastStride = inStrides[rank - 1];
            auto outLastStride = outStrides[rank - 1];
            FUNC_1D func = [buff, outBuff, inLastStride, outLastStride, threshold](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
                auto buffPart = buff + start * 8 * inLastStride;
                auto outBuffPart = outBuff + start * outLastStride;
                const int64_t len = stop - start;
                for (int64_t i = 0; i < len; ++i) {
                    *outBuffPart = pack<X>(buffPart, inLastStride, threshold);
                    buffPart += 8 * inLastStride;
                    outBuffPart += outLastStride;
                }
            };
            samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
        } else {
            // If the output shape is {n1, n2, n3} then the input shape is {n1, n2, n3 * 8}.
            // The input is therefore viewed as {n1, n2, n3, 8}: the last "real" stride is
            // multiplied by 8 and one extra innermost stride is appended. The extra slot
            // means rank + 1 entries are needed, hence MAX_RANK + 1 (a plain MAX_RANK
            // array is overrun when rank == MAX_RANK). Zero-initialised because the whole
            // array is copied into the lambda below.
            Nd4jLong extendedStrides[MAX_RANK + 1] = {};
            for (int i = 0; i < rank; i++) {
                extendedStrides[i] = inStrides[i];
            }
            extendedStrides[rank - 1] = 8 * inStrides[rank - 1];
            extendedStrides[rank] = inStrides[rank - 1];

            // Generic case: slow, could be specialised further for common layouts.
            FUNC_1D func = [rank, buff, outBuff, outShapes, extendedStrides, outStrides, threshold]
                    (uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
                Nd4jLong coords[MAX_RANK] = {};
                Nd4jLong* ptr_coords = coords;
                const int64_t len = stop - start;
                // stride between the eight members of one group (the appended slot)
                auto innermostStride = extendedStrides[rank];
                sd::index2coords_C(start, rank, outShapes, ptr_coords);
                // Only the first `rank` strides take part in the coordinate walk, so the
                // extended input view and the output are iterated with the same coordinates.
                auto offset = sd::offset_from_coords(extendedStrides, outStrides, ptr_coords, rank);
                for (int64_t k = 0; k < len; ++k) {
                    auto buffPart = &(buff[offset.first]);
                    auto outBuffPart = &(outBuff[offset.second]);
                    *outBuffPart = pack<X>(buffPart, innermostStride, threshold);
                    offset = inc_coords(outShapes, extendedStrides, outStrides, ptr_coords, offset, rank);
                }
            };
            samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
        }
    }
}
/////////////////////////////////////////////////////////////
/**
 * Public entry point of the compare-and-bitpack helper.
 *
 * Hands (input, threshold, output) to compareAndBitpack_ through
 * BUILD_SINGLE_SELECTOR, keyed on input.dataType() over LIBND4J_TYPES.
 *
 * @param block     graph context; not used by this function
 * @param input     array whose elements are compared against the threshold
 * @param threshold array carrying the threshold (compareAndBitpack_ reads element 0)
 * @param output    array receiving the packed bytes
 */
void compareAndBitpack(sd::graph::Context& block, const NDArray& input, const NDArray& threshold, NDArray& output) {
BUILD_SINGLE_SELECTOR(input.dataType(), compareAndBitpack_, (input, threshold, output), LIBND4J_TYPES);
}
// Explicit instantiation of compareAndBitpack_ over LIBND4J_TYPES
// (presumably one instantiation per listed type; the macro is defined outside this file).
BUILD_SINGLE_TEMPLATE(template void compareAndBitpack_, (const NDArray& input, const NDArray& threshold, NDArray& output), LIBND4J_TYPES);
}
}
}