Compare_and_bitpack: It was reimplemented. now the last dimension should be divisible by 8
Signed-off-by: AbdelRauf <rauf@konduit.ai>
This commit is contained in:
		
							parent
							
								
									b66454d593
								
							
						
					
					
						commit
						fe22bd5726
					
				@ -23,6 +23,7 @@
 | 
				
			|||||||
#include <ops/declarable/generic/helpers/BroadcastHelper.h>
 | 
					#include <ops/declarable/generic/helpers/BroadcastHelper.h>
 | 
				
			||||||
#include <ops/declarable/headers/parity_ops.h>
 | 
					#include <ops/declarable/headers/parity_ops.h>
 | 
				
			||||||
#include <ops/declarable/headers/datatypes.h>
 | 
					#include <ops/declarable/headers/datatypes.h>
 | 
				
			||||||
 | 
					#include <ops/declarable/helpers/transforms.h>
 | 
				
			||||||
#include <array/NDArrayFactory.h>
 | 
					#include <array/NDArrayFactory.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace sd {
 | 
					namespace sd {
 | 
				
			||||||
@ -31,17 +32,9 @@ namespace sd {
 | 
				
			|||||||
            auto x = INPUT_VARIABLE(0);
 | 
					            auto x = INPUT_VARIABLE(0);
 | 
				
			||||||
            auto y = INPUT_VARIABLE(1);
 | 
					            auto y = INPUT_VARIABLE(1);
 | 
				
			||||||
            auto z = OUTPUT_VARIABLE(0);
 | 
					            auto z = OUTPUT_VARIABLE(0);
 | 
				
			||||||
            auto z0 = NDArrayFactory::create<bool>(x->ordering(), x->getShapeAsVector(), block.launchContext());
 | 
					 | 
				
			||||||
            BROADCAST_CHECK_EMPTY(x, y, (&z0));
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0);
 | 
					            sd::ops::helpers::compareAndBitpack(block, *x, *y, *z);
 | 
				
			||||||
            bitcast res;
 | 
					            return Status::OK();
 | 
				
			||||||
            auto status = res.execute({ tZ }, { z }, {}, { DataType::UINT8 }, {}, {}, false);
 | 
					 | 
				
			||||||
            if (tZ != &z0) {
 | 
					 | 
				
			||||||
                delete tZ;
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            return status;
 | 
					 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        DECLARE_TYPES(compare_and_bitpack) {
 | 
					        DECLARE_TYPES(compare_and_bitpack) {
 | 
				
			||||||
@ -53,9 +46,15 @@ namespace sd {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        DECLARE_SHAPE_FN(compare_and_bitpack) {
 | 
					        DECLARE_SHAPE_FN(compare_and_bitpack) {
 | 
				
			||||||
            auto inShape = inputShape->at(0);
 | 
					            auto inShape = inputShape->at(0);
 | 
				
			||||||
 | 
					            auto shapes = shape::shapeOf(inShape);
 | 
				
			||||||
 | 
					            const int rank = shape::rank(inShape);
 | 
				
			||||||
 | 
					            REQUIRE_TRUE(!shape::isScalar(inShape), 0, "Input should not be a scalar");
 | 
				
			||||||
 | 
					            std::vector<Nd4jLong> shapeDims {shapes, shapes + rank};
 | 
				
			||||||
 | 
					            REQUIRE_TRUE(shapeDims[rank-1] % 8 ==0 , 0, "Last dimension of the input (which is %i) should be divisible by 8 ", shapeDims[rank-1]);
 | 
				
			||||||
 | 
					            shapeDims[rank-1] = shapeDims[rank-1] / 8 ;
 | 
				
			||||||
            DataType newType = DataType::UINT8;
 | 
					            DataType newType = DataType::UINT8;
 | 
				
			||||||
 | 
					            auto outputShape = ConstantShapeHelper::getInstance().createShapeInfo(newType, shape::order(inShape), shapeDims);
 | 
				
			||||||
            return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(inShape, newType)));
 | 
					            return SHAPELIST(outputShape);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
				
			|||||||
@ -1908,15 +1908,15 @@ namespace sd {
 | 
				
			|||||||
        #endif
 | 
					        #endif
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        /**
 | 
					        /**
 | 
				
			||||||
         * compare_and_bitpack - compare with greater and pack result with uint8
 | 
					         * compare_and_bitpack - Compare values of input to threshold and pack resulting bits into a uint8
 | 
				
			||||||
         *
 | 
					         *
 | 
				
			||||||
         * input params:
 | 
					         * input params:
 | 
				
			||||||
         *    0 - NDArray (input)
 | 
					         *    0 - NDArray (input). Note: last dimension should be divisibly by 8
 | 
				
			||||||
         *    1 - 0D Tensor - threshold
 | 
					         *    1 - 0D Tensor - threshold to compare against. Note: when input and threshold is bool type, the threshold is ignored
 | 
				
			||||||
         *
 | 
					         *
 | 
				
			||||||
         *
 | 
					         *
 | 
				
			||||||
         * output:
 | 
					         * output:
 | 
				
			||||||
         *    0 - NDArray with the same shape as input and type uint8
 | 
					         *    0 - NDArray with the shape as {input.dim0,...input.dimLast/8} and type uint8
 | 
				
			||||||
         */
 | 
					         */
 | 
				
			||||||
        #if NOT_EXCLUDED(OP_compare_and_bitpack)
 | 
					        #if NOT_EXCLUDED(OP_compare_and_bitpack)
 | 
				
			||||||
        DECLARE_CUSTOM_OP(compare_and_bitpack, 2, 1, false, 0, 0);
 | 
					        DECLARE_CUSTOM_OP(compare_and_bitpack, 2, 1, false, 0, 0);
 | 
				
			||||||
 | 
				
			|||||||
@ -0,0 +1,191 @@
 | 
				
			|||||||
 | 
					/*
 | 
				
			||||||
 | 
					 *  ******************************************************************************
 | 
				
			||||||
 | 
					 *  *
 | 
				
			||||||
 | 
					 *  *
 | 
				
			||||||
 | 
					 *  * This program and the accompanying materials are made available under the
 | 
				
			||||||
 | 
					 *  * terms of the Apache License, Version 2.0 which is available at
 | 
				
			||||||
 | 
					 *  * https://www.apache.org/licenses/LICENSE-2.0.
 | 
				
			||||||
 | 
					 *  *
 | 
				
			||||||
 | 
					 *  * See the NOTICE file distributed with this work for additional
 | 
				
			||||||
 | 
					 *  * information regarding copyright ownership.
 | 
				
			||||||
 | 
					 *  * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
				
			||||||
 | 
					 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
				
			||||||
 | 
					 *  * License for the specific language governing permissions and limitations
 | 
				
			||||||
 | 
					 *  * under the License.
 | 
				
			||||||
 | 
					 *  *
 | 
				
			||||||
 | 
					 *  * SPDX-License-Identifier: Apache-2.0
 | 
				
			||||||
 | 
					 *  *****************************************************************************
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 //
 | 
				
			||||||
 | 
					 // @author AbdelRauf 
 | 
				
			||||||
 | 
					 //
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <type_traits>
 | 
				
			||||||
 | 
					#include <cmath>
 | 
				
			||||||
 | 
					#include <stdexcept>
 | 
				
			||||||
 | 
					#include <memory>
 | 
				
			||||||
 | 
					#include <execution/Threads.h>
 | 
				
			||||||
 | 
					#include <execution/ThreadPool.h>
 | 
				
			||||||
 | 
					#include <helpers/LoopsCoordsHelper.h>
 | 
				
			||||||
 | 
					#include <ops/declarable/helpers/transforms.h>
 | 
				
			||||||
 | 
					#include <helpers/LoopsCoordsHelper.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace sd {
 | 
				
			||||||
 | 
					    namespace ops {
 | 
				
			||||||
 | 
					        namespace helpers {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            template<typename X>
 | 
				
			||||||
 | 
					            uint8_t pack(const X* buff, const X& threshold){
 | 
				
			||||||
 | 
					                uint8_t res;
 | 
				
			||||||
 | 
					                res = (buff[0] > threshold) << 7;
 | 
				
			||||||
 | 
					                res = res | ((buff[1] > threshold) << 6); 
 | 
				
			||||||
 | 
					                res = res | ((buff[2] > threshold) << 5);
 | 
				
			||||||
 | 
					                res = res | ((buff[3] > threshold) << 4);
 | 
				
			||||||
 | 
					                res = res | ((buff[4] > threshold) << 3);
 | 
				
			||||||
 | 
					                res = res | ((buff[5] > threshold) << 2);
 | 
				
			||||||
 | 
					                res = res | ((buff[6] > threshold) << 1);
 | 
				
			||||||
 | 
					                res = res | (buff[7] > threshold);
 | 
				
			||||||
 | 
					                return res;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            template<>
 | 
				
			||||||
 | 
					            uint8_t pack<bool>(const bool* buff, const bool &threshold){
 | 
				
			||||||
 | 
					                //ignore threshold
 | 
				
			||||||
 | 
					                uint8_t res;
 | 
				
			||||||
 | 
					                res = buff[0] << 7;
 | 
				
			||||||
 | 
					                res = res | (buff[1] << 6); 
 | 
				
			||||||
 | 
					                res = res | (buff[2] << 5);
 | 
				
			||||||
 | 
					                res = res | (buff[3] << 4);
 | 
				
			||||||
 | 
					                res = res | (buff[4] << 3);
 | 
				
			||||||
 | 
					                res = res | (buff[5] << 2);
 | 
				
			||||||
 | 
					                res = res | (buff[6] << 1);
 | 
				
			||||||
 | 
					                res = res | buff[7] ;
 | 
				
			||||||
 | 
					                return res;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            template<typename X>
 | 
				
			||||||
 | 
					            uint8_t pack(const X* buff, int stride, const X& threshold){
 | 
				
			||||||
 | 
					                uint8_t res;
 | 
				
			||||||
 | 
					                res = (buff[0] > threshold) << 7;
 | 
				
			||||||
 | 
					                res = res | ((buff[1*stride] > threshold) << 6); 
 | 
				
			||||||
 | 
					                res = res | ((buff[2*stride] > threshold) << 5);
 | 
				
			||||||
 | 
					                res = res | ((buff[3*stride] > threshold) << 4);
 | 
				
			||||||
 | 
					                res = res | ((buff[4*stride] > threshold) << 3);
 | 
				
			||||||
 | 
					                res = res | ((buff[5*stride] > threshold) << 2);
 | 
				
			||||||
 | 
					                res = res | ((buff[6*stride] > threshold) << 1);
 | 
				
			||||||
 | 
					                res = res | (buff[7*stride] > threshold);
 | 
				
			||||||
 | 
					                return res;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            template<>
 | 
				
			||||||
 | 
					            uint8_t pack<bool>(const bool* buff, int stride, const bool &threshold){
 | 
				
			||||||
 | 
					                //ignore threshold
 | 
				
			||||||
 | 
					                uint8_t res;
 | 
				
			||||||
 | 
					                res = buff[0] << 7;
 | 
				
			||||||
 | 
					                res = res | (buff[1*stride] << 6); 
 | 
				
			||||||
 | 
					                res = res | (buff[2*stride] << 5);
 | 
				
			||||||
 | 
					                res = res | (buff[3*stride] << 4);
 | 
				
			||||||
 | 
					                res = res | (buff[4*stride] << 3);
 | 
				
			||||||
 | 
					                res = res | (buff[5*stride] << 2);
 | 
				
			||||||
 | 
					                res = res | (buff[6*stride] << 1);
 | 
				
			||||||
 | 
					                res = res | buff[7*stride] ;
 | 
				
			||||||
 | 
					                return res;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            template <typename X>
 | 
				
			||||||
 | 
					            void compareAndBitpack_(const NDArray& input, const NDArray& thresholdScalar, NDArray& output) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    auto rank =input.rankOf();
 | 
				
			||||||
 | 
					                    X threshold = thresholdScalar.e<X>(0);
 | 
				
			||||||
 | 
					                    auto buff = input.bufferAsT<X>();
 | 
				
			||||||
 | 
					                    uint8_t *outBuff = output.bufferAsT<uint8_t>();
 | 
				
			||||||
 | 
					                    if(input.ordering()=='c' && output.ordering()=='c' && input.ews()==1 && output.ews()==1){
 | 
				
			||||||
 | 
					                        FUNC_1D func = [buff, outBuff, threshold](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
 | 
				
			||||||
 | 
					                                //nd4j_printf("s: %i e: %i \n", (int)start,(int)stop);
 | 
				
			||||||
 | 
					                                auto outBuffPart = outBuff + start;
 | 
				
			||||||
 | 
					                                auto buffPart = buff + start*8;
 | 
				
			||||||
 | 
					                                auto len = stop-start; 
 | 
				
			||||||
 | 
					                                //run 
 | 
				
			||||||
 | 
					                                for(auto i=0;i < len; i++){ 
 | 
				
			||||||
 | 
					                                    outBuffPart[i] = pack<X>(&(buffPart[8*i]), threshold);
 | 
				
			||||||
 | 
					                                }
 | 
				
			||||||
 | 
					                        };
 | 
				
			||||||
 | 
					                        samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                    else{
 | 
				
			||||||
 | 
					                        
 | 
				
			||||||
 | 
					                        auto inShapes = input.shapeOf();
 | 
				
			||||||
 | 
					                        auto outShapes = output.shapeOf();
 | 
				
			||||||
 | 
					                        auto inStrides = input.stridesOf();
 | 
				
			||||||
 | 
					                        auto outStrides = output.stridesOf();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        if(rank == 1){
 | 
				
			||||||
 | 
					                            auto inLastStride = inStrides[rank-1];
 | 
				
			||||||
 | 
					                            auto outLastStride = outStrides[rank-1];
 | 
				
			||||||
 | 
					                            FUNC_1D func = [buff, outBuff, inLastStride, outLastStride,  threshold](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
 | 
				
			||||||
 | 
					                                    //nd4j_printf("rankkk s: %i e: %i \n", (int)start,(int)stop);
 | 
				
			||||||
 | 
					                                    auto buffPart = buff + start*8*inLastStride;
 | 
				
			||||||
 | 
					                                    auto outBuffPart = outBuff + start* outLastStride;
 | 
				
			||||||
 | 
					                                    auto len = stop-start; 
 | 
				
			||||||
 | 
					                                    //run 
 | 
				
			||||||
 | 
					                                    for(auto i=0;i < len; i++){ 
 | 
				
			||||||
 | 
					                                        *outBuffPart = pack<X>(buffPart, inLastStride, threshold);
 | 
				
			||||||
 | 
					                                        buffPart += 8*inLastStride;
 | 
				
			||||||
 | 
					                                        outBuffPart += outLastStride;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                                    }
 | 
				
			||||||
 | 
					                            };
 | 
				
			||||||
 | 
					                            samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
 | 
				
			||||||
 | 
					                        }else{
 | 
				
			||||||
 | 
					                            //if output shape is {n1, n2, n3} then input shape is { n1. n2, n3 * 8}
 | 
				
			||||||
 | 
					                            //therefore we can split input shape  {n1, n2, n3 , 8} and correct its stride
 | 
				
			||||||
 | 
					                            //as we do not need last shape info. lets just extend and correct its stride
 | 
				
			||||||
 | 
					                            Nd4jLong extendedStrides[MAX_RANK];
 | 
				
			||||||
 | 
					                            for(int i=0;i<rank; i++){
 | 
				
			||||||
 | 
					                                extendedStrides[i] = inStrides[i];
 | 
				
			||||||
 | 
					                            }
 | 
				
			||||||
 | 
					                            //lets correct new stride
 | 
				
			||||||
 | 
					                            extendedStrides[rank-1] = 8*inStrides[rank-1];
 | 
				
			||||||
 | 
					                            extendedStrides[rank] = inStrides[rank-1];
 | 
				
			||||||
 | 
					                            //general case. its slow. we can improve it for special case later
 | 
				
			||||||
 | 
					                            //generic case that could be further imrpoved. for now its slow
 | 
				
			||||||
 | 
					                            FUNC_1D func = [rank, buff, outBuff, outShapes, extendedStrides, outStrides, threshold]
 | 
				
			||||||
 | 
					                            (uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
 | 
				
			||||||
 | 
					                                    Nd4jLong coords[MAX_RANK] = {};
 | 
				
			||||||
 | 
					                                    Nd4jLong* ptr_coords = (Nd4jLong*)&coords;
 | 
				
			||||||
 | 
					                                    //nd4j_printf("generic s: %i e: %i \n", (int)start,(int)stop);
 | 
				
			||||||
 | 
					                                    auto len = (stop-start);
 | 
				
			||||||
 | 
					                                    // its extended as {rank+1} so extendedStrides[rank] is valid 
 | 
				
			||||||
 | 
					                                    auto innermostStride = extendedStrides[rank];
 | 
				
			||||||
 | 
					                                    sd::index2coords_C(start, rank, outShapes, ptr_coords);
 | 
				
			||||||
 | 
					                                    //here last dimension will not be in coords. this way output shape and input shapes are equal
 | 
				
			||||||
 | 
					                                    auto offset = sd::offset_from_coords(extendedStrides, outStrides, ptr_coords, rank);
 | 
				
			||||||
 | 
					                                    for(auto k=0; k < len; k++){
 | 
				
			||||||
 | 
					                                        auto buffPart = &(buff[offset.first]);
 | 
				
			||||||
 | 
					                                        auto outBuffPart = &(outBuff[offset.second]);
 | 
				
			||||||
 | 
					                                        *outBuffPart = pack<X>(buffPart, innermostStride, threshold);
 | 
				
			||||||
 | 
					                                        offset = inc_coords(outShapes, extendedStrides, outStrides, ptr_coords, offset, rank);
 | 
				
			||||||
 | 
					                                    }
 | 
				
			||||||
 | 
					                            };
 | 
				
			||||||
 | 
					                            samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            /////////////////////////////////////////////////////////////
 | 
				
			||||||
 | 
					            void compareAndBitpack(sd::graph::Context& block, const NDArray& input, const NDArray& threshold, NDArray& output) {
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					                BUILD_SINGLE_SELECTOR(input.dataType(), compareAndBitpack_, (input, threshold, output), LIBND4J_TYPES);
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            BUILD_SINGLE_TEMPLATE(template void compareAndBitpack_, (const NDArray& input, const NDArray& threshold, NDArray& output), LIBND4J_TYPES);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -0,0 +1,191 @@
 | 
				
			|||||||
 | 
					/*
 | 
				
			||||||
 | 
					 *  ******************************************************************************
 | 
				
			||||||
 | 
					 *  *
 | 
				
			||||||
 | 
					 *  *
 | 
				
			||||||
 | 
					 *  * This program and the accompanying materials are made available under the
 | 
				
			||||||
 | 
					 *  * terms of the Apache License, Version 2.0 which is available at
 | 
				
			||||||
 | 
					 *  * https://www.apache.org/licenses/LICENSE-2.0.
 | 
				
			||||||
 | 
					 *  *
 | 
				
			||||||
 | 
					 *  * See the NOTICE file distributed with this work for additional
 | 
				
			||||||
 | 
					 *  * information regarding copyright ownership.
 | 
				
			||||||
 | 
					 *  * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
				
			||||||
 | 
					 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
				
			||||||
 | 
					 *  * License for the specific language governing permissions and limitations
 | 
				
			||||||
 | 
					 *  * under the License.
 | 
				
			||||||
 | 
					 *  *
 | 
				
			||||||
 | 
					 *  * SPDX-License-Identifier: Apache-2.0
 | 
				
			||||||
 | 
					 *  *****************************************************************************
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 //
 | 
				
			||||||
 | 
					 // @author AbdelRauf 
 | 
				
			||||||
 | 
					 //
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <system/op_boilerplate.h>
 | 
				
			||||||
 | 
					#include <ops/declarable/helpers/imagesHelpers.h>
 | 
				
			||||||
 | 
					#include <helpers/ConstantTadHelper.h>
 | 
				
			||||||
 | 
					#include <ops/declarable/helpers/adjust_hue.h>
 | 
				
			||||||
 | 
					#include <helpers/PointersManager.h>
 | 
				
			||||||
 | 
					#include <ops/declarable/helpers/transforms.h>
 | 
				
			||||||
 | 
					#include <helpers/LoopsCoordsHelper.h>
 | 
				
			||||||
 | 
					namespace sd    {
 | 
				
			||||||
 | 
					namespace ops     {
 | 
				
			||||||
 | 
					namespace helpers {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    template<typename X>
 | 
				
			||||||
 | 
					    _CUDA_HD uint8_t pack(const X* buff, const X& threshold){
 | 
				
			||||||
 | 
					        uint8_t res;
 | 
				
			||||||
 | 
					        res = (buff[0] > threshold) << 7;
 | 
				
			||||||
 | 
					        res = res | ((buff[1] > threshold) << 6); 
 | 
				
			||||||
 | 
					        res = res | ((buff[2] > threshold) << 5);
 | 
				
			||||||
 | 
					        res = res | ((buff[3] > threshold) << 4);
 | 
				
			||||||
 | 
					        res = res | ((buff[4] > threshold) << 3);
 | 
				
			||||||
 | 
					        res = res | ((buff[5] > threshold) << 2);
 | 
				
			||||||
 | 
					        res = res | ((buff[6] > threshold) << 1);
 | 
				
			||||||
 | 
					        res = res | (buff[7] > threshold);
 | 
				
			||||||
 | 
					        return res;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    template<>
 | 
				
			||||||
 | 
					    _CUDA_HD uint8_t pack<bool>(const bool* buff, const bool &threshold){
 | 
				
			||||||
 | 
					        //ignore threshold
 | 
				
			||||||
 | 
					        uint8_t res;
 | 
				
			||||||
 | 
					        res = buff[0] << 7;
 | 
				
			||||||
 | 
					        res = res | (buff[1] << 6); 
 | 
				
			||||||
 | 
					        res = res | (buff[2] << 5);
 | 
				
			||||||
 | 
					        res = res | (buff[3] << 4);
 | 
				
			||||||
 | 
					        res = res | (buff[4] << 3);
 | 
				
			||||||
 | 
					        res = res | (buff[5] << 2);
 | 
				
			||||||
 | 
					        res = res | (buff[6] << 1);
 | 
				
			||||||
 | 
					        res = res | buff[7] ;
 | 
				
			||||||
 | 
					        return res;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    template<typename X>
 | 
				
			||||||
 | 
					    _CUDA_HD uint8_t pack(const X* buff, int stride, const X& threshold){
 | 
				
			||||||
 | 
					        uint8_t res;
 | 
				
			||||||
 | 
					        res = (buff[0] > threshold) << 7;
 | 
				
			||||||
 | 
					        res = res | ((buff[1*stride] > threshold) << 6); 
 | 
				
			||||||
 | 
					        res = res | ((buff[2*stride] > threshold) << 5);
 | 
				
			||||||
 | 
					        res = res | ((buff[3*stride] > threshold) << 4);
 | 
				
			||||||
 | 
					        res = res | ((buff[4*stride] > threshold) << 3);
 | 
				
			||||||
 | 
					        res = res | ((buff[5*stride] > threshold) << 2);
 | 
				
			||||||
 | 
					        res = res | ((buff[6*stride] > threshold) << 1);
 | 
				
			||||||
 | 
					        res = res | (buff[7*stride] > threshold);
 | 
				
			||||||
 | 
					        return res;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    template<>
 | 
				
			||||||
 | 
					    _CUDA_HD uint8_t pack<bool>(const bool* buff, int stride, const bool &threshold){
 | 
				
			||||||
 | 
					        //ignore threshold
 | 
				
			||||||
 | 
					        uint8_t res;
 | 
				
			||||||
 | 
					        res = buff[0] << 7;
 | 
				
			||||||
 | 
					        res = res | (buff[1*stride] << 6); 
 | 
				
			||||||
 | 
					        res = res | (buff[2*stride] << 5);
 | 
				
			||||||
 | 
					        res = res | (buff[3*stride] << 4);
 | 
				
			||||||
 | 
					        res = res | (buff[4*stride] << 3);
 | 
				
			||||||
 | 
					        res = res | (buff[5*stride] << 2);
 | 
				
			||||||
 | 
					        res = res | (buff[6*stride] << 1);
 | 
				
			||||||
 | 
					        res = res | buff[7*stride] ;
 | 
				
			||||||
 | 
					        return res;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					///////////////////////////////////////////////////////////////////
 | 
				
			||||||
 | 
					template <typename T>
 | 
				
			||||||
 | 
					static void _CUDA_G cmpBitpack(const void* vx, void* vz,  int rank, int len, const Nd4jLong *xStridesExtended, const Nd4jLong *outPutShapeInfo, T threshold) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const T* x = reinterpret_cast<const T*>(vx);
 | 
				
			||||||
 | 
					    uint8_t* z = reinterpret_cast<uint8_t*>(vz);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
 | 
				
			||||||
 | 
					    auto shapes = shape::shapeOf(outPutShapeInfo);
 | 
				
			||||||
 | 
					    auto zStrides = shape::stride(outPutShapeInfo);
 | 
				
			||||||
 | 
					    Nd4jLong coords[MAX_RANK] = {};
 | 
				
			||||||
 | 
					    Nd4jLong* ptr_coords = (Nd4jLong*)&coords;
 | 
				
			||||||
 | 
					    // its extended as {rank+1} so xStridesExtended[rank] is valid 
 | 
				
			||||||
 | 
					    auto inLastStride = xStridesExtended[rank];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for(auto k=tid; k < len; k+=gridDim.x * blockDim.x){
 | 
				
			||||||
 | 
					        sd::index2coords_C(k, rank, shapes, ptr_coords); 
 | 
				
			||||||
 | 
					        auto offset = sd::offset_from_coords(xStridesExtended, zStrides, ptr_coords, rank); 
 | 
				
			||||||
 | 
					        auto buffPart = &(x[offset.first]);
 | 
				
			||||||
 | 
					        auto outBuffPart = &(z[offset.second]);
 | 
				
			||||||
 | 
					        *outBuffPart = pack<T>(buffPart, inLastStride, threshold);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <typename T>
 | 
				
			||||||
 | 
					static void _CUDA_G cmpBitpackEws(const void* vx, void* vz,  int len, const Nd4jLong xStride,  const Nd4jLong yStride,  T threshold) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const T* x = reinterpret_cast<const T*>(vx);
 | 
				
			||||||
 | 
					    uint8_t* z = reinterpret_cast<uint8_t*>(vz);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
 | 
				
			||||||
 | 
					    if(xStride==1){
 | 
				
			||||||
 | 
					        for(auto k=tid; k < len; k+=gridDim.x * blockDim.x){
 | 
				
			||||||
 | 
					            auto buffPart = &(x[k*8]);
 | 
				
			||||||
 | 
					            auto outBuffPart = &(z[k*yStride]);
 | 
				
			||||||
 | 
					            *outBuffPart = pack<T>(buffPart, threshold); 
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }else{
 | 
				
			||||||
 | 
					        for(auto k=tid; k < len; k+=gridDim.x * blockDim.x){
 | 
				
			||||||
 | 
					            auto buffPart = &(x[k*8*xStride]);
 | 
				
			||||||
 | 
					            auto outBuffPart = &(z[k*yStride]);
 | 
				
			||||||
 | 
					            *outBuffPart =  pack<T>(buffPart, xStride, threshold); 
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					///////////////////////////////////////////////////////////////////
 | 
				
			||||||
 | 
					template<typename T>
 | 
				
			||||||
 | 
					static _CUDA_H void cmpBitpackCudaLauncher(sd::graph::Context& block, const NDArray& input, const NDArray& thresholdScalar, NDArray& output) {
 | 
				
			||||||
 | 
					    T threshold = thresholdScalar.e<T>(0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto inStrides = input.stridesOf();
 | 
				
			||||||
 | 
					    auto rank = output.rankOf();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    //threadblock size
 | 
				
			||||||
 | 
					    const int threadsPerBlock = MAX_NUM_THREADS / 2;
 | 
				
			||||||
 | 
					    //grid size
 | 
				
			||||||
 | 
					    const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
 | 
				
			||||||
 | 
					    auto stream = block.launchContext()->getCudaStream();
 | 
				
			||||||
 | 
					    //nd4j_printf("n %i g %i th %i \n", output.lengthOf(), blocksPerGrid, threadsPerBlock);
 | 
				
			||||||
 | 
					    PointersManager manager(block.launchContext(), "compare_and_bitpack");
 | 
				
			||||||
 | 
					    NDArray::prepareSpecialUse({&output}, {&input});
 | 
				
			||||||
 | 
					    if(input.ews()>0 && output.ews()>0 && input.ordering()=='c' && output.ordering()=='c'){
 | 
				
			||||||
 | 
					        cmpBitpackEws<T><<<blocksPerGrid, threadsPerBlock >>>(input.specialBuffer(), output.specialBuffer(), output.lengthOf(), inStrides[rank-1], output.stridesOf()[rank-1] , threshold);
 | 
				
			||||||
 | 
					    }else{
 | 
				
			||||||
 | 
					        //if output shape is {n1, n2, n3} then input shape is { n1. n2, n3 * 8}
 | 
				
			||||||
 | 
					        //therefore we can split input shape  {n1, n2, n3 , 8} and correct its stride
 | 
				
			||||||
 | 
					        //as we do not need last shape info. lets just extend and correct its stride
 | 
				
			||||||
 | 
					        Nd4jLong extendedStrides[MAX_RANK];
 | 
				
			||||||
 | 
					        for(int i=0;i<rank; i++){
 | 
				
			||||||
 | 
					            extendedStrides[i] = inStrides[i];
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        //lets correct new stride
 | 
				
			||||||
 | 
					        extendedStrides[rank-1] = 8*inStrides[rank-1];
 | 
				
			||||||
 | 
					        extendedStrides[rank] = inStrides[rank-1];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        auto strideSize = (rank+1)*sizeof(Nd4jLong); 
 | 
				
			||||||
 | 
					        Nd4jLong* extendedStridesDevPtr = reinterpret_cast<Nd4jLong*>(manager.replicatePointer(extendedStrides, strideSize));
 | 
				
			||||||
 | 
					        cmpBitpack<T><<<blocksPerGrid, threadsPerBlock >>>(input.specialBuffer(), output.specialBuffer(), rank, output.lengthOf(), extendedStridesDevPtr, output.specialShapeInfo(), threshold);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    NDArray::registerSpecialUse({&output}, {&input});
 | 
				
			||||||
 | 
					    manager.synchronize(); 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void compareAndBitpack(sd::graph::Context& block, const NDArray& input, const NDArray& threshold, NDArray& output)  {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    BUILD_SINGLE_SELECTOR(input.dataType(), cmpBitpackCudaLauncher, (block, input, threshold, output), LIBND4J_TYPES);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -26,7 +26,7 @@
 | 
				
			|||||||
#include <ops/declarable/helpers/helpers.h>
 | 
					#include <ops/declarable/helpers/helpers.h>
 | 
				
			||||||
#include <helpers/helper_random.h>
 | 
					#include <helpers/helper_random.h>
 | 
				
			||||||
#include <graph/RandomGenerator.h>
 | 
					#include <graph/RandomGenerator.h>
 | 
				
			||||||
 | 
					#include <graph/Context.h>
 | 
				
			||||||
namespace sd    {
 | 
					namespace sd    {
 | 
				
			||||||
namespace ops     {
 | 
					namespace ops     {
 | 
				
			||||||
namespace helpers {
 | 
					namespace helpers {
 | 
				
			||||||
@ -84,6 +84,8 @@ namespace helpers {
 | 
				
			|||||||
	void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps);
 | 
						void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	void split(sd::LaunchContext* context, const NDArray& input, std::vector<NDArray*>& outArrs, const int axis);
 | 
						void split(sd::LaunchContext* context, const NDArray& input, std::vector<NDArray*>& outArrs, const int axis);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						void compareAndBitpack(graph::Context& block, const NDArray& input, const NDArray& threshold, NDArray& output);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -27,7 +27,7 @@
 | 
				
			|||||||
#include <ops/ops.h>
 | 
					#include <ops/ops.h>
 | 
				
			||||||
#include <helpers/GradCheck.h>
 | 
					#include <helpers/GradCheck.h>
 | 
				
			||||||
#include <loops/random.h>
 | 
					#include <loops/random.h>
 | 
				
			||||||
 | 
					#include <array/DataType.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
using namespace sd;
 | 
					using namespace sd;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1757,6 +1757,208 @@ TEST_F(DeclarableOpsTests9, prelu_test14) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					////////////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
 | 
					TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto x = NDArrayFactory::create<float>('c', {2, 3, 16}, {
 | 
				
			||||||
 | 
					    0.865595f, 0.381197f, 0.911656f, 0.256752f, 0.084921f, 0.070434f, 0.469923f, 0.269935f, 0.510656f, 0.949777f, 0.926772f, 0.622540f, 0.688253f, 0.164974f,
 | 
				
			||||||
 | 
					    0.068558f, 0.031173f, 0.910035f, 0.219362f, 0.731336f, 0.135392f, 0.449875f, 0.020135f, 0.891820f, 0.907567f, 0.114376f, 0.652253f, 0.892939f, 0.698095f,
 | 
				
			||||||
 | 
					    0.423831f, 0.971155f, 0.968733f, 0.194465f, 0.852475f, 0.642962f, 0.417665f, 0.768379f, 0.753035f, 0.738440f, 0.046251f, 0.659487f, 0.486230f, 0.246724f,
 | 
				
			||||||
 | 
					    0.276700f, 0.103631f, 0.843105f, 0.562587f, 0.784459f, 0.109871f, 0.455828f, 0.129641f, 0.002471f, 0.148281f, 0.976162f, 0.603573f, 0.752530f, 0.249840f,
 | 
				
			||||||
 | 
					    0.723716f, 0.658430f, 0.661057f, 0.328042f, 0.338351f, 0.903157f, 0.485580f, 0.405103f, 0.335052f, 0.509858f, 0.764852f, 0.764527f, 0.382572f, 0.962121f,
 | 
				
			||||||
 | 
					    0.296145f, 0.602766f, 0.169683f, 0.750371f, 0.993936f, 0.914704f, 0.199342f, 0.858098f, 0.617198f, 0.219334f, 0.167574f, 0.305204f, 0.960773f, 0.537944f,
 | 
				
			||||||
 | 
					    0.245441f, 0.787276f, 0.968920f, 0.980918f, 0.615237f, 0.355165f, 0.480441f, 0.304282f, 0.961229f, 0.639195f, 0.017776f, 0.836153f
 | 
				
			||||||
 | 
					    });
 | 
				
			||||||
 | 
					    auto threshold = NDArrayFactory::create<float>(0.5f);
 | 
				
			||||||
 | 
					    auto exp = NDArrayFactory::create<uint8_t>('c', {2, 3, 2}, {160, 248, 163, 118, 221, 14, 14, 228, 117, 118, 55, 141});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    sd::ops::compare_and_bitpack op;
 | 
				
			||||||
 | 
					    auto result = op.evaluate({&x, &threshold}, {}, {}, {});
 | 
				
			||||||
 | 
					    auto output = result.at(0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ASSERT_EQ(ND4J_STATUS_OK, result.status());
 | 
				
			||||||
 | 
					    ASSERT_TRUE(exp.isSameShape(output));
 | 
				
			||||||
 | 
					    ASSERT_TRUE(exp.equalsTo(output));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST_F(DeclarableOpsTests9, compare_and_bitpack_test2) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto x = NDArrayFactory::create<bool>('c', {2, 3, 16}, {
 | 
				
			||||||
 | 
					        true, false, true, false, false, false, false, false, true,
 | 
				
			||||||
 | 
					        true, true, true, true, false, false, false, true, false,
 | 
				
			||||||
 | 
					        true, false, false, false, true, true, false, true, true,
 | 
				
			||||||
 | 
					        true, false, true, true, false, true, true, false, true,
 | 
				
			||||||
 | 
					        true, true, false, true, false, false, false, false, true,
 | 
				
			||||||
 | 
					        true, true, false, false, false, false, false, true, true,
 | 
				
			||||||
 | 
					        true, false, true, true, true, false, false, true, false,
 | 
				
			||||||
 | 
					        false, false, true, true, true, false, true, false, true,
 | 
				
			||||||
 | 
					        false, true, true, true, false, true, true, false, false,
 | 
				
			||||||
 | 
					        false, true, true, false, true, true, true, true, false,
 | 
				
			||||||
 | 
					        false, false, true, true, false, true
 | 
				
			||||||
 | 
					    });
 | 
				
			||||||
 | 
					    //threshold is ignored here ,actually
 | 
				
			||||||
 | 
					    auto threshold = NDArrayFactory::create<bool>(true);
 | 
				
			||||||
 | 
					    auto exp = NDArrayFactory::create<uint8_t>('c', {2, 3, 2}, {160, 248, 163, 118, 221, 14, 14, 228, 117, 118, 55, 141});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    sd::ops::compare_and_bitpack op;
 | 
				
			||||||
 | 
					    auto result = op.evaluate({&x, &threshold}, {}, {}, {});
 | 
				
			||||||
 | 
					    auto output = result.at(0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ASSERT_EQ(ND4J_STATUS_OK, result.status());
 | 
				
			||||||
 | 
					    ASSERT_TRUE(exp.isSameShape(output));
 | 
				
			||||||
 | 
					    ASSERT_TRUE(exp.equalsTo(output));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					////////////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
 | 
					TEST_F(DeclarableOpsTests9, compare_and_bitpack_test3) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto x = NDArrayFactory::create<float>('c', {2, 0, 3, 16});
 | 
				
			||||||
 | 
					    auto threshold = NDArrayFactory::create<float>(0.5f);
 | 
				
			||||||
 | 
					    auto exp = NDArrayFactory::create<uint8_t>('c', {2, 0, 3, 2});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    sd::ops::compare_and_bitpack op;
 | 
				
			||||||
 | 
					    auto result = op.evaluate({&x, &threshold}, {}, {}, {});
 | 
				
			||||||
 | 
					    auto output = result.at(0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ASSERT_EQ(ND4J_STATUS_OK, result.status());
 | 
				
			||||||
 | 
					    output->printShapeInfo("output");
 | 
				
			||||||
 | 
					    ASSERT_TRUE(exp.isSameShape(output));
 | 
				
			||||||
 | 
					    ASSERT_TRUE(exp.equalsTo(output));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					////////////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
 | 
					TEST_F(DeclarableOpsTests9, compare_and_bitpack_test4) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto x = NDArrayFactory::create<float>('c', {2, 0, 3, 13});
 | 
				
			||||||
 | 
					    auto threshold = NDArrayFactory::create<float>(0.5f);
 | 
				
			||||||
 | 
					    sd::ops::compare_and_bitpack op; 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ASSERT_THROW(op.evaluate({&x, &threshold}, {}, {}, {}), std::invalid_argument); 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					////////////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
 | 
					TEST_F(DeclarableOpsTests9, compare_and_bitpack_test5) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto x = NDArrayFactory::create<float>('c', {2, 0, 3, 13});
 | 
				
			||||||
 | 
					    auto threshold = NDArrayFactory::create<float>(0.5f);
 | 
				
			||||||
 | 
					    auto out =  NDArrayFactory::create<uint8_t>('c', {2, 0, 3, 1});
 | 
				
			||||||
 | 
					    sd::ops::compare_and_bitpack op; 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ASSERT_THROW(op.execute({&x, &threshold}, {&out}, {}, {}), std::invalid_argument); 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					////////////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
 | 
					TEST_F(DeclarableOpsTests9, compare_and_bitpack_test6) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto x = NDArrayFactory::create<float>('c', {2, 0, 3, 8});
 | 
				
			||||||
 | 
					    auto threshold = NDArrayFactory::create<float>(0.5f);
 | 
				
			||||||
 | 
					    auto out =  NDArrayFactory::create<uint8_t>('c', {2, 0, 3, 2});
 | 
				
			||||||
 | 
					    sd::ops::compare_and_bitpack op; 
 | 
				
			||||||
 | 
					    //shape mismatch throws runtime error
 | 
				
			||||||
 | 
					    ASSERT_THROW(op.execute({&x, &threshold}, {&out}, {}, {}), std::runtime_error); 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST_F(DeclarableOpsTests9, compare_and_bitpack_test7) {
 | 
				
			||||||
 | 
					    constexpr int pp = 32*32*16;
 | 
				
			||||||
 | 
					    constexpr int s1 = 3; 
 | 
				
			||||||
 | 
					    constexpr int t1 = 8;
 | 
				
			||||||
 | 
					    std::vector<Nd4jLong> shape1 = {pp}; 
 | 
				
			||||||
 | 
					    std::vector<Nd4jLong> strides1 = {s1};
 | 
				
			||||||
 | 
					    std::vector<Nd4jLong> shape2 = {pp/8}; 
 | 
				
			||||||
 | 
					    std::vector<Nd4jLong> strides2 = {t1};
 | 
				
			||||||
 | 
					    ShapeDescriptor desc1 (DataType::BOOL, 'c', shape1, strides1, s1);
 | 
				
			||||||
 | 
					    ShapeDescriptor desc2 (DataType::UINT8, 'c', shape2, strides2, t1);
 | 
				
			||||||
 | 
					    auto x = NDArrayFactory::create(desc1);
 | 
				
			||||||
 | 
					    auto output = NDArrayFactory::create(desc2);
 | 
				
			||||||
 | 
					    auto exp =  NDArrayFactory::create(desc2);
 | 
				
			||||||
 | 
					    auto threshold = NDArrayFactory::create<bool>(true);
 | 
				
			||||||
 | 
					    auto buff = x.bufferAsT<bool>();
 | 
				
			||||||
 | 
						uint8_t *expBuff = exp.bufferAsT<uint8_t>();
 | 
				
			||||||
 | 
					    //generate test
 | 
				
			||||||
 | 
					    for(int l=0;l<pp; l+=8){
 | 
				
			||||||
 | 
					                uint8_t test =  rand() % 255;
 | 
				
			||||||
 | 
					                expBuff[l/8*t1] = test;
 | 
				
			||||||
 | 
					                auto buffP = &(buff[l*s1]);
 | 
				
			||||||
 | 
					                buffP[0] = test & (1<<7);
 | 
				
			||||||
 | 
					                buffP[1*s1] = test & (1<<6);
 | 
				
			||||||
 | 
					                buffP[2*s1] = test & (1<<5);
 | 
				
			||||||
 | 
					                buffP[3*s1] = test & (1<<4);
 | 
				
			||||||
 | 
					                buffP[4*s1] = test & (1<<3);
 | 
				
			||||||
 | 
					                buffP[5*s1] = test & (1<<2);
 | 
				
			||||||
 | 
					                buffP[6*s1] = test & (1<<1);
 | 
				
			||||||
 | 
					                buffP[7*s1] = test & 1;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    //explicit sync to device
 | 
				
			||||||
 | 
					    x.tickWriteHost();
 | 
				
			||||||
 | 
					    exp.tickWriteHost();
 | 
				
			||||||
 | 
					    x.syncToDevice();
 | 
				
			||||||
 | 
					    exp.syncToDevice();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    sd::ops::compare_and_bitpack op;
 | 
				
			||||||
 | 
					    auto result = op.execute({&x, &threshold}, {&output}, {}, {});
 | 
				
			||||||
 | 
					    ASSERT_EQ(Status::OK(), result);
 | 
				
			||||||
 | 
					    ASSERT_TRUE(exp.isSameShape(&output));
 | 
				
			||||||
 | 
					    ASSERT_TRUE(exp.equalsTo(&output));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST_F(DeclarableOpsTests9, compare_and_bitpack_test8) {
 | 
				
			||||||
 | 
					    constexpr int pp = 32;
 | 
				
			||||||
 | 
					    constexpr int s1 = 2;
 | 
				
			||||||
 | 
					    constexpr int s2 = (s1*pp) + 3;
 | 
				
			||||||
 | 
					    constexpr int s3 = (s2*pp) + 4;
 | 
				
			||||||
 | 
					    constexpr int t1 = 2;
 | 
				
			||||||
 | 
					    constexpr int t2 = (t1*pp/8) + 3;
 | 
				
			||||||
 | 
					    constexpr int t3 = (t2*pp) + 4;
 | 
				
			||||||
 | 
					    std::vector<Nd4jLong> shape1 = {pp,pp,pp}; 
 | 
				
			||||||
 | 
					    std::vector<Nd4jLong> strides1 = {s3 , s2 , s1};
 | 
				
			||||||
 | 
					    std::vector<Nd4jLong> shape2 = {pp,pp,pp/8}; 
 | 
				
			||||||
 | 
					    std::vector<Nd4jLong> strides2 = {t3 , t2 , t1};
 | 
				
			||||||
 | 
					    ShapeDescriptor desc1 (DataType::BOOL, 'c', shape1, strides1, 0);
 | 
				
			||||||
 | 
					    ShapeDescriptor desc2 (DataType::UINT8, 'c', shape2, strides2, 0);
 | 
				
			||||||
 | 
					    auto x = NDArrayFactory::create(desc1);
 | 
				
			||||||
 | 
					    auto output =  NDArrayFactory::create(desc2);
 | 
				
			||||||
 | 
					    auto exp =  NDArrayFactory::create(desc2);
 | 
				
			||||||
 | 
					    auto threshold = NDArrayFactory::create<bool>(true);
 | 
				
			||||||
 | 
					    auto buff = x.bufferAsT<bool>();
 | 
				
			||||||
 | 
						uint8_t *expBuff = exp.bufferAsT<uint8_t>();
 | 
				
			||||||
 | 
					    //generate test
 | 
				
			||||||
 | 
					    for(int i=0;i<pp;i++){
 | 
				
			||||||
 | 
					        for(int j=0;j<pp;j++){
 | 
				
			||||||
 | 
					            for(int l=0;l<pp; l+=8){
 | 
				
			||||||
 | 
					                uint8_t test =  rand() % 255;
 | 
				
			||||||
 | 
					                expBuff[l/8*t1 + j*t2 + i *t3] = test;
 | 
				
			||||||
 | 
					                auto buffP = &(buff[j*s2 + i *s3 + l*s1]);
 | 
				
			||||||
 | 
					                buffP[0] = test & (1<<7);
 | 
				
			||||||
 | 
					                buffP[1*s1] = test & (1<<6);
 | 
				
			||||||
 | 
					                buffP[2*s1] = test & (1<<5);
 | 
				
			||||||
 | 
					                buffP[3*s1] = test & (1<<4);
 | 
				
			||||||
 | 
					                buffP[4*s1] = test & (1<<3);
 | 
				
			||||||
 | 
					                buffP[5*s1] = test & (1<<2);
 | 
				
			||||||
 | 
					                buffP[6*s1] = test & (1<<1);
 | 
				
			||||||
 | 
					                buffP[7*s1] = test & 1;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    //explicit sync to device
 | 
				
			||||||
 | 
					    x.tickWriteHost();
 | 
				
			||||||
 | 
					    exp.tickWriteHost();
 | 
				
			||||||
 | 
					    x.syncToDevice();
 | 
				
			||||||
 | 
					    exp.syncToDevice();
 | 
				
			||||||
 | 
					    sd::ops::compare_and_bitpack op;
 | 
				
			||||||
 | 
					    auto result = op.execute({&x, &threshold}, {&output}, {}, {});
 | 
				
			||||||
 | 
					    ASSERT_EQ(Status::OK(), result);
 | 
				
			||||||
 | 
					    ASSERT_TRUE(exp.isSameShape(&output));
 | 
				
			||||||
 | 
					    ASSERT_TRUE(exp.equalsTo(&output));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
////////////////////////////////////////////////////////////////////////////////
 | 
					////////////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) {
 | 
					TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1776,25 +1978,6 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
////////////////////////////////////////////////////////////////////////////////
 | 
					 | 
				
			||||||
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) {
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
 | 
					 | 
				
			||||||
    auto threshold = NDArrayFactory::create<double>(2.0);
 | 
					 | 
				
			||||||
    auto exp = NDArrayFactory::create<uint8_t>('c', {2, 3, 4}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
 | 
					 | 
				
			||||||
                                                                                0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1});
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    sd::ops::compare_and_bitpack op;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    auto result = op.evaluate({&x, &threshold}, {}, {}, {});
 | 
					 | 
				
			||||||
    ASSERT_EQ(ND4J_STATUS_OK, result.status());
 | 
					 | 
				
			||||||
    auto output = result.at(0);
 | 
					 | 
				
			||||||
//    output->printIndexedBuffer("Packed to uint8");
 | 
					 | 
				
			||||||
    ASSERT_TRUE(exp.isSameShape(output));
 | 
					 | 
				
			||||||
    ASSERT_TRUE(exp.equalsTo(output));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
////////////////////////////////////////////////////////////////////////////////
 | 
					////////////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {
 | 
					TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user