/* * ****************************************************************************** * * * * * * This program and the accompanying materials are made available under the * * terms of the Apache License, Version 2.0 which is available at * * https://www.apache.org/licenses/LICENSE-2.0. * * * * See the NOTICE file distributed with this work for additional * * information regarding copyright ownership. * * Unless required by applicable law or agreed to in writing, software * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * * License for the specific language governing permissions and limitations * * under the License. * * * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ // // @author AbdelRauf // #include #include #include #include #include #include #include #include #include namespace sd { namespace ops { namespace helpers { template uint8_t pack(const X* buff, const X& threshold){ uint8_t res; res = (buff[0] > threshold) << 7; res = res | ((buff[1] > threshold) << 6); res = res | ((buff[2] > threshold) << 5); res = res | ((buff[3] > threshold) << 4); res = res | ((buff[4] > threshold) << 3); res = res | ((buff[5] > threshold) << 2); res = res | ((buff[6] > threshold) << 1); res = res | (buff[7] > threshold); return res; } template<> uint8_t pack(const bool* buff, const bool &threshold){ //ignore threshold uint8_t res; res = buff[0] << 7; res = res | (buff[1] << 6); res = res | (buff[2] << 5); res = res | (buff[3] << 4); res = res | (buff[4] << 3); res = res | (buff[5] << 2); res = res | (buff[6] << 1); res = res | buff[7] ; return res; } template uint8_t pack(const X* buff, int stride, const X& threshold){ uint8_t res; res = (buff[0] > threshold) << 7; res = res | ((buff[1*stride] > threshold) << 6); res = res | ((buff[2*stride] > threshold) << 5); res = res | ((buff[3*stride] > threshold) << 4); res = res | ((buff[4*stride] > threshold) << 3); res = res | ((buff[5*stride] > threshold) << 2); res = res | ((buff[6*stride] > threshold) << 1); res = res | (buff[7*stride] > threshold); return res; } template<> uint8_t pack(const bool* buff, int stride, const bool &threshold){ //ignore threshold uint8_t res; res = buff[0] << 7; res = res | (buff[1*stride] << 6); res = res | (buff[2*stride] << 5); res = res | (buff[3*stride] << 4); res = res | (buff[4*stride] << 3); res = res | (buff[5*stride] << 2); res = res | (buff[6*stride] << 1); res = res | buff[7*stride] ; return res; } template void compareAndBitpack_(const NDArray& input, const NDArray& thresholdScalar, NDArray& output) { auto rank =input.rankOf(); X threshold = thresholdScalar.e(0); auto buff = input.bufferAsT(); uint8_t *outBuff = output.bufferAsT(); if(input.ordering()=='c' && output.ordering()=='c' && input.ews()==1 && output.ews()==1){ FUNC_1D func = [buff, outBuff, threshold](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void { //nd4j_printf("s: %i e: %i \n", (int)start,(int)stop); auto outBuffPart = outBuff + start; auto buffPart = buff + start*8; auto len = stop-start; //run for(auto i=0;i < len; i++){ outBuffPart[i] = pack(&(buffPart[8*i]), threshold); } }; samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1); } else{ auto inShapes = input.shapeOf(); auto outShapes = output.shapeOf(); auto inStrides = input.stridesOf(); auto outStrides = output.stridesOf(); if(rank == 1){ auto inLastStride = inStrides[rank-1]; auto outLastStride = outStrides[rank-1]; FUNC_1D func = [buff, outBuff, inLastStride, outLastStride, threshold](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void { //nd4j_printf("rankkk s: %i e: %i \n", (int)start,(int)stop); auto buffPart = buff + start*8*inLastStride; auto outBuffPart = outBuff + start* outLastStride; auto len = stop-start; //run for(auto i=0;i < len; i++){ *outBuffPart = pack(buffPart, inLastStride, threshold); buffPart += 8*inLastStride; outBuffPart += outLastStride; } }; samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1); }else{ //if output shape is {n1, n2, n3} then input shape is { n1. n2, n3 * 8} //therefore we can split input shape {n1, n2, n3 , 8} and correct its stride //as we do not need last shape info. lets just extend and correct its stride Nd4jLong extendedStrides[MAX_RANK]; for(int i=0;i void { Nd4jLong coords[MAX_RANK] = {}; Nd4jLong* ptr_coords = (Nd4jLong*)&coords; //nd4j_printf("generic s: %i e: %i \n", (int)start,(int)stop); auto len = (stop-start); // its extended as {rank+1} so extendedStrides[rank] is valid auto innermostStride = extendedStrides[rank]; sd::index2coords_C(start, rank, outShapes, ptr_coords); //here last dimension will not be in coords. this way output shape and input shapes are equal auto offset = sd::offset_from_coords(extendedStrides, outStrides, ptr_coords, rank); for(auto k=0; k < len; k++){ auto buffPart = &(buff[offset.first]); auto outBuffPart = &(outBuff[offset.second]); *outBuffPart = pack(buffPart, innermostStride, threshold); offset = inc_coords(outShapes, extendedStrides, outStrides, ptr_coords, offset, rank); } }; samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1); } } } ///////////////////////////////////////////////////////////// void compareAndBitpack(sd::graph::Context& block, const NDArray& input, const NDArray& threshold, NDArray& output) { BUILD_SINGLE_SELECTOR(input.dataType(), compareAndBitpack_, (input, threshold, output), LIBND4J_TYPES); } BUILD_SINGLE_TEMPLATE(template void compareAndBitpack_, (const NDArray& input, const NDArray& threshold, NDArray& output), LIBND4J_TYPES); } } }