192 lines
9.4 KiB
C++
192 lines
9.4 KiB
C++
|
/*
|
||
|
* ******************************************************************************
|
||
|
* *
|
||
|
* *
|
||
|
* * This program and the accompanying materials are made available under the
|
||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||
|
* *
|
||
|
* * See the NOTICE file distributed with this work for additional
|
||
|
* * information regarding copyright ownership.
|
||
|
* * Unless required by applicable law or agreed to in writing, software
|
||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||
|
* * License for the specific language governing permissions and limitations
|
||
|
* * under the License.
|
||
|
* *
|
||
|
* * SPDX-License-Identifier: Apache-2.0
|
||
|
* *****************************************************************************
|
||
|
*/
|
||
|
|
||
|
//
|
||
|
// @author AbdelRauf
|
||
|
//
|
||
|
|
||
|
#include <type_traits>
|
||
|
#include <cmath>
|
||
|
#include <stdexcept>
|
||
|
#include <memory>
|
||
|
#include <execution/Threads.h>
|
||
|
#include <execution/ThreadPool.h>
|
||
|
#include <helpers/LoopsCoordsHelper.h>
|
||
|
#include <ops/declarable/helpers/transforms.h>
|
||
|
#include <helpers/LoopsCoordsHelper.h>
|
||
|
|
||
|
namespace sd {
|
||
|
namespace ops {
|
||
|
namespace helpers {
|
||
|
|
||
|
|
||
|
template<typename X>
|
||
|
uint8_t pack(const X* buff, const X& threshold){
|
||
|
uint8_t res;
|
||
|
res = (buff[0] > threshold) << 7;
|
||
|
res = res | ((buff[1] > threshold) << 6);
|
||
|
res = res | ((buff[2] > threshold) << 5);
|
||
|
res = res | ((buff[3] > threshold) << 4);
|
||
|
res = res | ((buff[4] > threshold) << 3);
|
||
|
res = res | ((buff[5] > threshold) << 2);
|
||
|
res = res | ((buff[6] > threshold) << 1);
|
||
|
res = res | (buff[7] > threshold);
|
||
|
return res;
|
||
|
}
|
||
|
|
||
|
template<>
|
||
|
uint8_t pack<bool>(const bool* buff, const bool &threshold){
|
||
|
//ignore threshold
|
||
|
uint8_t res;
|
||
|
res = buff[0] << 7;
|
||
|
res = res | (buff[1] << 6);
|
||
|
res = res | (buff[2] << 5);
|
||
|
res = res | (buff[3] << 4);
|
||
|
res = res | (buff[4] << 3);
|
||
|
res = res | (buff[5] << 2);
|
||
|
res = res | (buff[6] << 1);
|
||
|
res = res | buff[7] ;
|
||
|
return res;
|
||
|
}
|
||
|
|
||
|
template<typename X>
|
||
|
uint8_t pack(const X* buff, int stride, const X& threshold){
|
||
|
uint8_t res;
|
||
|
res = (buff[0] > threshold) << 7;
|
||
|
res = res | ((buff[1*stride] > threshold) << 6);
|
||
|
res = res | ((buff[2*stride] > threshold) << 5);
|
||
|
res = res | ((buff[3*stride] > threshold) << 4);
|
||
|
res = res | ((buff[4*stride] > threshold) << 3);
|
||
|
res = res | ((buff[5*stride] > threshold) << 2);
|
||
|
res = res | ((buff[6*stride] > threshold) << 1);
|
||
|
res = res | (buff[7*stride] > threshold);
|
||
|
return res;
|
||
|
}
|
||
|
|
||
|
template<>
|
||
|
uint8_t pack<bool>(const bool* buff, int stride, const bool &threshold){
|
||
|
//ignore threshold
|
||
|
uint8_t res;
|
||
|
res = buff[0] << 7;
|
||
|
res = res | (buff[1*stride] << 6);
|
||
|
res = res | (buff[2*stride] << 5);
|
||
|
res = res | (buff[3*stride] << 4);
|
||
|
res = res | (buff[4*stride] << 3);
|
||
|
res = res | (buff[5*stride] << 2);
|
||
|
res = res | (buff[6*stride] << 1);
|
||
|
res = res | buff[7*stride] ;
|
||
|
return res;
|
||
|
}
|
||
|
|
||
|
|
||
|
template <typename X>
|
||
|
void compareAndBitpack_(const NDArray& input, const NDArray& thresholdScalar, NDArray& output) {
|
||
|
|
||
|
auto rank =input.rankOf();
|
||
|
X threshold = thresholdScalar.e<X>(0);
|
||
|
auto buff = input.bufferAsT<X>();
|
||
|
uint8_t *outBuff = output.bufferAsT<uint8_t>();
|
||
|
if(input.ordering()=='c' && output.ordering()=='c' && input.ews()==1 && output.ews()==1){
|
||
|
FUNC_1D func = [buff, outBuff, threshold](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
|
||
|
//nd4j_printf("s: %i e: %i \n", (int)start,(int)stop);
|
||
|
auto outBuffPart = outBuff + start;
|
||
|
auto buffPart = buff + start*8;
|
||
|
auto len = stop-start;
|
||
|
//run
|
||
|
for(auto i=0;i < len; i++){
|
||
|
outBuffPart[i] = pack<X>(&(buffPart[8*i]), threshold);
|
||
|
}
|
||
|
};
|
||
|
samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
|
||
|
|
||
|
}
|
||
|
else{
|
||
|
|
||
|
auto inShapes = input.shapeOf();
|
||
|
auto outShapes = output.shapeOf();
|
||
|
auto inStrides = input.stridesOf();
|
||
|
auto outStrides = output.stridesOf();
|
||
|
|
||
|
if(rank == 1){
|
||
|
auto inLastStride = inStrides[rank-1];
|
||
|
auto outLastStride = outStrides[rank-1];
|
||
|
FUNC_1D func = [buff, outBuff, inLastStride, outLastStride, threshold](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
|
||
|
//nd4j_printf("rankkk s: %i e: %i \n", (int)start,(int)stop);
|
||
|
auto buffPart = buff + start*8*inLastStride;
|
||
|
auto outBuffPart = outBuff + start* outLastStride;
|
||
|
auto len = stop-start;
|
||
|
//run
|
||
|
for(auto i=0;i < len; i++){
|
||
|
*outBuffPart = pack<X>(buffPart, inLastStride, threshold);
|
||
|
buffPart += 8*inLastStride;
|
||
|
outBuffPart += outLastStride;
|
||
|
|
||
|
}
|
||
|
};
|
||
|
samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
|
||
|
}else{
|
||
|
//if output shape is {n1, n2, n3} then input shape is { n1. n2, n3 * 8}
|
||
|
//therefore we can split input shape {n1, n2, n3 , 8} and correct its stride
|
||
|
//as we do not need last shape info. lets just extend and correct its stride
|
||
|
Nd4jLong extendedStrides[MAX_RANK];
|
||
|
for(int i=0;i<rank; i++){
|
||
|
extendedStrides[i] = inStrides[i];
|
||
|
}
|
||
|
//lets correct new stride
|
||
|
extendedStrides[rank-1] = 8*inStrides[rank-1];
|
||
|
extendedStrides[rank] = inStrides[rank-1];
|
||
|
//general case. its slow. we can improve it for special case later
|
||
|
//generic case that could be further imrpoved. for now its slow
|
||
|
FUNC_1D func = [rank, buff, outBuff, outShapes, extendedStrides, outStrides, threshold]
|
||
|
(uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
|
||
|
Nd4jLong coords[MAX_RANK] = {};
|
||
|
Nd4jLong* ptr_coords = (Nd4jLong*)&coords;
|
||
|
//nd4j_printf("generic s: %i e: %i \n", (int)start,(int)stop);
|
||
|
auto len = (stop-start);
|
||
|
// its extended as {rank+1} so extendedStrides[rank] is valid
|
||
|
auto innermostStride = extendedStrides[rank];
|
||
|
sd::index2coords_C(start, rank, outShapes, ptr_coords);
|
||
|
//here last dimension will not be in coords. this way output shape and input shapes are equal
|
||
|
auto offset = sd::offset_from_coords(extendedStrides, outStrides, ptr_coords, rank);
|
||
|
for(auto k=0; k < len; k++){
|
||
|
auto buffPart = &(buff[offset.first]);
|
||
|
auto outBuffPart = &(outBuff[offset.second]);
|
||
|
*outBuffPart = pack<X>(buffPart, innermostStride, threshold);
|
||
|
offset = inc_coords(outShapes, extendedStrides, outStrides, ptr_coords, offset, rank);
|
||
|
}
|
||
|
};
|
||
|
samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
|
||
|
}
|
||
|
|
||
|
}
|
||
|
}
|
||
|
|
||
|
/////////////////////////////////////////////////////////////
|
||
|
void compareAndBitpack(sd::graph::Context& block, const NDArray& input, const NDArray& threshold, NDArray& output) {
|
||
|
|
||
|
BUILD_SINGLE_SELECTOR(input.dataType(), compareAndBitpack_, (input, threshold, output), LIBND4J_TYPES);
|
||
|
}
|
||
|
|
||
|
BUILD_SINGLE_TEMPLATE(template void compareAndBitpack_, (const NDArray& input, const NDArray& threshold, NDArray& output), LIBND4J_TYPES);
|
||
|
|
||
|
}
|
||
|
}
|
||
|
}
|