Compare_and_bitpack: It was reimplemented. now the last dimension should be divisible by 8
Signed-off-by: AbdelRauf <rauf@konduit.ai>
This commit is contained in:
		
							parent
							
								
									b66454d593
								
							
						
					
					
						commit
						fe22bd5726
					
				| @ -23,6 +23,7 @@ | ||||
| #include <ops/declarable/generic/helpers/BroadcastHelper.h> | ||||
| #include <ops/declarable/headers/parity_ops.h> | ||||
| #include <ops/declarable/headers/datatypes.h> | ||||
| #include <ops/declarable/helpers/transforms.h> | ||||
| #include <array/NDArrayFactory.h> | ||||
| 
 | ||||
| namespace sd { | ||||
| @ -31,17 +32,9 @@ namespace sd { | ||||
|             auto x = INPUT_VARIABLE(0); | ||||
|             auto y = INPUT_VARIABLE(1); | ||||
|             auto z = OUTPUT_VARIABLE(0); | ||||
|             auto z0 = NDArrayFactory::create<bool>(x->ordering(), x->getShapeAsVector(), block.launchContext()); | ||||
|             BROADCAST_CHECK_EMPTY(x, y, (&z0)); | ||||
| 
 | ||||
|             auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0); | ||||
|             bitcast res; | ||||
|             auto status = res.execute({ tZ }, { z }, {}, { DataType::UINT8 }, {}, {}, false); | ||||
|             if (tZ != &z0) { | ||||
|                 delete tZ; | ||||
|             } | ||||
| 
 | ||||
|             return status; | ||||
|             sd::ops::helpers::compareAndBitpack(block, *x, *y, *z); | ||||
|             return Status::OK(); | ||||
|         } | ||||
| 
 | ||||
|         DECLARE_TYPES(compare_and_bitpack) { | ||||
| @ -53,9 +46,15 @@ namespace sd { | ||||
| 
 | ||||
|         DECLARE_SHAPE_FN(compare_and_bitpack) { | ||||
|             auto inShape = inputShape->at(0); | ||||
|             auto shapes = shape::shapeOf(inShape); | ||||
|             const int rank = shape::rank(inShape); | ||||
|             REQUIRE_TRUE(!shape::isScalar(inShape), 0, "Input should not be a scalar"); | ||||
|             std::vector<Nd4jLong> shapeDims {shapes, shapes + rank}; | ||||
|             REQUIRE_TRUE(shapeDims[rank-1] % 8 ==0 , 0, "Last dimension of the input (which is %i) should be divisible by 8 ", shapeDims[rank-1]); | ||||
|             shapeDims[rank-1] = shapeDims[rank-1] / 8 ; | ||||
|             DataType newType = DataType::UINT8; | ||||
| 
 | ||||
|             return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(inShape, newType))); | ||||
|             auto outputShape = ConstantShapeHelper::getInstance().createShapeInfo(newType, shape::order(inShape), shapeDims); | ||||
|             return SHAPELIST(outputShape); | ||||
|         } | ||||
| 
 | ||||
|     } | ||||
|  | ||||
| @ -1908,15 +1908,15 @@ namespace sd { | ||||
|         #endif | ||||
| 
 | ||||
|         /**
 | ||||
|          * compare_and_bitpack - compare with greater and pack result with uint8 | ||||
|          * compare_and_bitpack - Compare values of input to threshold and pack resulting bits into a uint8 | ||||
|          * | ||||
|          * input params: | ||||
|          *    0 - NDArray (input) | ||||
|          *    1 - 0D Tensor - threshold | ||||
|          *    0 - NDArray (input). Note: last dimension should be divisible by 8 | ||||
|          *    1 - 0D Tensor - threshold to compare against. Note: when input and threshold are of bool type, the threshold is ignored | ||||
|          * | ||||
|          * | ||||
|          * output: | ||||
|          *    0 - NDArray with the same shape as input and type uint8 | ||||
|          *    0 - NDArray of shape {input.dim0, ..., input.dimLast/8} and type uint8 | ||||
|          */ | ||||
|         #if NOT_EXCLUDED(OP_compare_and_bitpack) | ||||
|         DECLARE_CUSTOM_OP(compare_and_bitpack, 2, 1, false, 0, 0); | ||||
|  | ||||
| @ -0,0 +1,191 @@ | ||||
| /*
 | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0.
 | ||||
|  *  * | ||||
|  *  * See the NOTICE file distributed with this work for additional | ||||
|  *  * information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
|  //
 | ||||
|  // @author AbdelRauf 
 | ||||
|  //
 | ||||
| 
 | ||||
| #include <type_traits> | ||||
| #include <cmath> | ||||
| #include <stdexcept> | ||||
| #include <memory> | ||||
| #include <execution/Threads.h> | ||||
| #include <execution/ThreadPool.h> | ||||
| #include <helpers/LoopsCoordsHelper.h> | ||||
| #include <ops/declarable/helpers/transforms.h> | ||||
| #include <helpers/LoopsCoordsHelper.h> | ||||
| 
 | ||||
| namespace sd { | ||||
|     namespace ops { | ||||
|         namespace helpers { | ||||
| 
 | ||||
| 
 | ||||
|             template<typename X> | ||||
|             uint8_t pack(const X* buff, const X& threshold){ | ||||
|                 uint8_t res; | ||||
|                 res = (buff[0] > threshold) << 7; | ||||
|                 res = res | ((buff[1] > threshold) << 6);  | ||||
|                 res = res | ((buff[2] > threshold) << 5); | ||||
|                 res = res | ((buff[3] > threshold) << 4); | ||||
|                 res = res | ((buff[4] > threshold) << 3); | ||||
|                 res = res | ((buff[5] > threshold) << 2); | ||||
|                 res = res | ((buff[6] > threshold) << 1); | ||||
|                 res = res | (buff[7] > threshold); | ||||
|                 return res; | ||||
|             } | ||||
| 
 | ||||
|             template<> | ||||
|             uint8_t pack<bool>(const bool* buff, const bool &threshold){ | ||||
|                 //ignore threshold
 | ||||
|                 uint8_t res; | ||||
|                 res = buff[0] << 7; | ||||
|                 res = res | (buff[1] << 6);  | ||||
|                 res = res | (buff[2] << 5); | ||||
|                 res = res | (buff[3] << 4); | ||||
|                 res = res | (buff[4] << 3); | ||||
|                 res = res | (buff[5] << 2); | ||||
|                 res = res | (buff[6] << 1); | ||||
|                 res = res | buff[7] ; | ||||
|                 return res; | ||||
|             } | ||||
| 
 | ||||
|             template<typename X> | ||||
|             uint8_t pack(const X* buff, int stride, const X& threshold){ | ||||
|                 uint8_t res; | ||||
|                 res = (buff[0] > threshold) << 7; | ||||
|                 res = res | ((buff[1*stride] > threshold) << 6);  | ||||
|                 res = res | ((buff[2*stride] > threshold) << 5); | ||||
|                 res = res | ((buff[3*stride] > threshold) << 4); | ||||
|                 res = res | ((buff[4*stride] > threshold) << 3); | ||||
|                 res = res | ((buff[5*stride] > threshold) << 2); | ||||
|                 res = res | ((buff[6*stride] > threshold) << 1); | ||||
|                 res = res | (buff[7*stride] > threshold); | ||||
|                 return res; | ||||
|             } | ||||
| 
 | ||||
|             template<> | ||||
|             uint8_t pack<bool>(const bool* buff, int stride, const bool &threshold){ | ||||
|                 //ignore threshold
 | ||||
|                 uint8_t res; | ||||
|                 res = buff[0] << 7; | ||||
|                 res = res | (buff[1*stride] << 6);  | ||||
|                 res = res | (buff[2*stride] << 5); | ||||
|                 res = res | (buff[3*stride] << 4); | ||||
|                 res = res | (buff[4*stride] << 3); | ||||
|                 res = res | (buff[5*stride] << 2); | ||||
|                 res = res | (buff[6*stride] << 1); | ||||
|                 res = res | buff[7*stride] ; | ||||
|                 return res; | ||||
|             } | ||||
| 
 | ||||
| 
 | ||||
|             template <typename X> | ||||
|             void compareAndBitpack_(const NDArray& input, const NDArray& thresholdScalar, NDArray& output) { | ||||
| 
 | ||||
|                     auto rank =input.rankOf(); | ||||
|                     X threshold = thresholdScalar.e<X>(0); | ||||
|                     auto buff = input.bufferAsT<X>(); | ||||
|                     uint8_t *outBuff = output.bufferAsT<uint8_t>(); | ||||
|                     if(input.ordering()=='c' && output.ordering()=='c' && input.ews()==1 && output.ews()==1){ | ||||
|                         FUNC_1D func = [buff, outBuff, threshold](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void { | ||||
|                                 //nd4j_printf("s: %i e: %i \n", (int)start,(int)stop);
 | ||||
|                                 auto outBuffPart = outBuff + start; | ||||
|                                 auto buffPart = buff + start*8; | ||||
|                                 auto len = stop-start;  | ||||
|                                 //run 
 | ||||
|                                 for(auto i=0;i < len; i++){  | ||||
|                                     outBuffPart[i] = pack<X>(&(buffPart[8*i]), threshold); | ||||
|                                 } | ||||
|                         }; | ||||
|                         samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1); | ||||
| 
 | ||||
|                     } | ||||
|                     else{ | ||||
|                          | ||||
|                         auto inShapes = input.shapeOf(); | ||||
|                         auto outShapes = output.shapeOf(); | ||||
|                         auto inStrides = input.stridesOf(); | ||||
|                         auto outStrides = output.stridesOf(); | ||||
| 
 | ||||
|                         if(rank == 1){ | ||||
|                             auto inLastStride = inStrides[rank-1]; | ||||
|                             auto outLastStride = outStrides[rank-1]; | ||||
|                             FUNC_1D func = [buff, outBuff, inLastStride, outLastStride,  threshold](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void { | ||||
|                                     //nd4j_printf("rankkk s: %i e: %i \n", (int)start,(int)stop);
 | ||||
|                                     auto buffPart = buff + start*8*inLastStride; | ||||
|                                     auto outBuffPart = outBuff + start* outLastStride; | ||||
|                                     auto len = stop-start;  | ||||
|                                     //run 
 | ||||
|                                     for(auto i=0;i < len; i++){  | ||||
|                                         *outBuffPart = pack<X>(buffPart, inLastStride, threshold); | ||||
|                                         buffPart += 8*inLastStride; | ||||
|                                         outBuffPart += outLastStride; | ||||
| 
 | ||||
|                                     } | ||||
|                             }; | ||||
|                             samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1); | ||||
|                         }else{ | ||||
|                             //if output shape is {n1, n2, n3} then input shape is { n1. n2, n3 * 8}
 | ||||
|                             //therefore we can split input shape  {n1, n2, n3 , 8} and correct its stride
 | ||||
|                             //as we do not need last shape info. lets just extend and correct its stride
 | ||||
|                             Nd4jLong extendedStrides[MAX_RANK]; | ||||
|                             for(int i=0;i<rank; i++){ | ||||
|                                 extendedStrides[i] = inStrides[i]; | ||||
|                             } | ||||
|                             //lets correct new stride
 | ||||
|                             extendedStrides[rank-1] = 8*inStrides[rank-1]; | ||||
|                             extendedStrides[rank] = inStrides[rank-1]; | ||||
|                             //general case. its slow. we can improve it for special case later
 | ||||
|                             //generic case that could be further imrpoved. for now its slow
 | ||||
|                             FUNC_1D func = [rank, buff, outBuff, outShapes, extendedStrides, outStrides, threshold] | ||||
|                             (uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void { | ||||
|                                     Nd4jLong coords[MAX_RANK] = {}; | ||||
|                                     Nd4jLong* ptr_coords = (Nd4jLong*)&coords; | ||||
|                                     //nd4j_printf("generic s: %i e: %i \n", (int)start,(int)stop);
 | ||||
|                                     auto len = (stop-start); | ||||
|                                     // its extended as {rank+1} so extendedStrides[rank] is valid 
 | ||||
|                                     auto innermostStride = extendedStrides[rank]; | ||||
|                                     sd::index2coords_C(start, rank, outShapes, ptr_coords); | ||||
|                                     //here last dimension will not be in coords. this way output shape and input shapes are equal
 | ||||
|                                     auto offset = sd::offset_from_coords(extendedStrides, outStrides, ptr_coords, rank); | ||||
|                                     for(auto k=0; k < len; k++){ | ||||
|                                         auto buffPart = &(buff[offset.first]); | ||||
|                                         auto outBuffPart = &(outBuff[offset.second]); | ||||
|                                         *outBuffPart = pack<X>(buffPart, innermostStride, threshold); | ||||
|                                         offset = inc_coords(outShapes, extendedStrides, outStrides, ptr_coords, offset, rank); | ||||
|                                     } | ||||
|                             }; | ||||
|                             samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1); | ||||
|                         } | ||||
| 
 | ||||
|                     } | ||||
|             } | ||||
| 
 | ||||
|             /////////////////////////////////////////////////////////////
 | ||||
|             void compareAndBitpack(sd::graph::Context& block, const NDArray& input, const NDArray& threshold, NDArray& output) { | ||||
|   | ||||
|                 BUILD_SINGLE_SELECTOR(input.dataType(), compareAndBitpack_, (input, threshold, output), LIBND4J_TYPES); | ||||
|             } | ||||
| 
 | ||||
|             BUILD_SINGLE_TEMPLATE(template void compareAndBitpack_, (const NDArray& input, const NDArray& threshold, NDArray& output), LIBND4J_TYPES); | ||||
| 
 | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @ -0,0 +1,191 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  * See the NOTICE file distributed with this work for additional | ||||
|  *  * information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
|  // | ||||
|  // @author AbdelRauf  | ||||
|  // | ||||
| 
 | ||||
| #include <system/op_boilerplate.h> | ||||
| #include <ops/declarable/helpers/imagesHelpers.h> | ||||
| #include <helpers/ConstantTadHelper.h> | ||||
| #include <ops/declarable/helpers/adjust_hue.h> | ||||
| #include <helpers/PointersManager.h> | ||||
| #include <ops/declarable/helpers/transforms.h> | ||||
| #include <helpers/LoopsCoordsHelper.h> | ||||
| namespace sd    { | ||||
| namespace ops     { | ||||
| namespace helpers { | ||||
| 
 | ||||
|     template<typename X> | ||||
|     _CUDA_HD uint8_t pack(const X* buff, const X& threshold){ | ||||
|         uint8_t res; | ||||
|         res = (buff[0] > threshold) << 7; | ||||
|         res = res | ((buff[1] > threshold) << 6);  | ||||
|         res = res | ((buff[2] > threshold) << 5); | ||||
|         res = res | ((buff[3] > threshold) << 4); | ||||
|         res = res | ((buff[4] > threshold) << 3); | ||||
|         res = res | ((buff[5] > threshold) << 2); | ||||
|         res = res | ((buff[6] > threshold) << 1); | ||||
|         res = res | (buff[7] > threshold); | ||||
|         return res; | ||||
|     } | ||||
| 
 | ||||
|     template<> | ||||
|     _CUDA_HD uint8_t pack<bool>(const bool* buff, const bool &threshold){ | ||||
|         //ignore threshold | ||||
|         uint8_t res; | ||||
|         res = buff[0] << 7; | ||||
|         res = res | (buff[1] << 6);  | ||||
|         res = res | (buff[2] << 5); | ||||
|         res = res | (buff[3] << 4); | ||||
|         res = res | (buff[4] << 3); | ||||
|         res = res | (buff[5] << 2); | ||||
|         res = res | (buff[6] << 1); | ||||
|         res = res | buff[7] ; | ||||
|         return res; | ||||
|     } | ||||
| 
 | ||||
|     template<typename X> | ||||
|     _CUDA_HD uint8_t pack(const X* buff, int stride, const X& threshold){ | ||||
|         uint8_t res; | ||||
|         res = (buff[0] > threshold) << 7; | ||||
|         res = res | ((buff[1*stride] > threshold) << 6);  | ||||
|         res = res | ((buff[2*stride] > threshold) << 5); | ||||
|         res = res | ((buff[3*stride] > threshold) << 4); | ||||
|         res = res | ((buff[4*stride] > threshold) << 3); | ||||
|         res = res | ((buff[5*stride] > threshold) << 2); | ||||
|         res = res | ((buff[6*stride] > threshold) << 1); | ||||
|         res = res | (buff[7*stride] > threshold); | ||||
|         return res; | ||||
|     } | ||||
| 
 | ||||
|     template<> | ||||
|     _CUDA_HD uint8_t pack<bool>(const bool* buff, int stride, const bool &threshold){ | ||||
|         //ignore threshold | ||||
|         uint8_t res; | ||||
|         res = buff[0] << 7; | ||||
|         res = res | (buff[1*stride] << 6);  | ||||
|         res = res | (buff[2*stride] << 5); | ||||
|         res = res | (buff[3*stride] << 4); | ||||
|         res = res | (buff[4*stride] << 3); | ||||
|         res = res | (buff[5*stride] << 2); | ||||
|         res = res | (buff[6*stride] << 1); | ||||
|         res = res | buff[7*stride] ; | ||||
|         return res; | ||||
|     } | ||||
| /////////////////////////////////////////////////////////////////// | ||||
| template <typename T> | ||||
| static void _CUDA_G cmpBitpack(const void* vx, void* vz,  int rank, int len, const Nd4jLong *xStridesExtended, const Nd4jLong *outPutShapeInfo, T threshold) { | ||||
| 
 | ||||
|     const T* x = reinterpret_cast<const T*>(vx); | ||||
|     uint8_t* z = reinterpret_cast<uint8_t*>(vz); | ||||
| 
 | ||||
|     const auto tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
|     auto shapes = shape::shapeOf(outPutShapeInfo); | ||||
|     auto zStrides = shape::stride(outPutShapeInfo); | ||||
|     Nd4jLong coords[MAX_RANK] = {}; | ||||
|     Nd4jLong* ptr_coords = (Nd4jLong*)&coords; | ||||
|     // its extended as {rank+1} so xStridesExtended[rank] is valid  | ||||
|     auto inLastStride = xStridesExtended[rank]; | ||||
| 
 | ||||
|     for(auto k=tid; k < len; k+=gridDim.x * blockDim.x){ | ||||
|         sd::index2coords_C(k, rank, shapes, ptr_coords);  | ||||
|         auto offset = sd::offset_from_coords(xStridesExtended, zStrides, ptr_coords, rank);  | ||||
|         auto buffPart = &(x[offset.first]); | ||||
|         auto outBuffPart = &(z[offset.second]); | ||||
|         *outBuffPart = pack<T>(buffPart, inLastStride, threshold); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| template <typename T> | ||||
| static void _CUDA_G cmpBitpackEws(const void* vx, void* vz,  int len, const Nd4jLong xStride,  const Nd4jLong yStride,  T threshold) { | ||||
| 
 | ||||
|     const T* x = reinterpret_cast<const T*>(vx); | ||||
|     uint8_t* z = reinterpret_cast<uint8_t*>(vz); | ||||
| 
 | ||||
|     const auto tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
|     if(xStride==1){ | ||||
|         for(auto k=tid; k < len; k+=gridDim.x * blockDim.x){ | ||||
|             auto buffPart = &(x[k*8]); | ||||
|             auto outBuffPart = &(z[k*yStride]); | ||||
|             *outBuffPart = pack<T>(buffPart, threshold);  | ||||
|         } | ||||
|     }else{ | ||||
|         for(auto k=tid; k < len; k+=gridDim.x * blockDim.x){ | ||||
|             auto buffPart = &(x[k*8*xStride]); | ||||
|             auto outBuffPart = &(z[k*yStride]); | ||||
|             *outBuffPart =  pack<T>(buffPart, xStride, threshold);  | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /////////////////////////////////////////////////////////////////// | ||||
| template<typename T> | ||||
| static _CUDA_H void cmpBitpackCudaLauncher(sd::graph::Context& block, const NDArray& input, const NDArray& thresholdScalar, NDArray& output) { | ||||
|     T threshold = thresholdScalar.e<T>(0); | ||||
| 
 | ||||
| 
 | ||||
|     auto inStrides = input.stridesOf(); | ||||
|     auto rank = output.rankOf(); | ||||
| 
 | ||||
|     //threadblock size | ||||
|     const int threadsPerBlock = MAX_NUM_THREADS / 2; | ||||
|     //grid size | ||||
|     const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; | ||||
|     auto stream = block.launchContext()->getCudaStream(); | ||||
|     //nd4j_printf("n %i g %i th %i \n", output.lengthOf(), blocksPerGrid, threadsPerBlock); | ||||
|     PointersManager manager(block.launchContext(), "compare_and_bitpack"); | ||||
|     NDArray::prepareSpecialUse({&output}, {&input}); | ||||
|     if(input.ews()>0 && output.ews()>0 && input.ordering()=='c' && output.ordering()=='c'){ | ||||
|         cmpBitpackEws<T><<<blocksPerGrid, threadsPerBlock >>>(input.specialBuffer(), output.specialBuffer(), output.lengthOf(), inStrides[rank-1], output.stridesOf()[rank-1] , threshold); | ||||
|     }else{ | ||||
|         //if output shape is {n1, n2, n3} then input shape is { n1. n2, n3 * 8} | ||||
|         //therefore we can split input shape  {n1, n2, n3 , 8} and correct its stride | ||||
|         //as we do not need last shape info. lets just extend and correct its stride | ||||
|         Nd4jLong extendedStrides[MAX_RANK]; | ||||
|         for(int i=0;i<rank; i++){ | ||||
|             extendedStrides[i] = inStrides[i]; | ||||
|         } | ||||
|         //lets correct new stride | ||||
|         extendedStrides[rank-1] = 8*inStrides[rank-1]; | ||||
|         extendedStrides[rank] = inStrides[rank-1]; | ||||
| 
 | ||||
|         auto strideSize = (rank+1)*sizeof(Nd4jLong);  | ||||
|         Nd4jLong* extendedStridesDevPtr = reinterpret_cast<Nd4jLong*>(manager.replicatePointer(extendedStrides, strideSize)); | ||||
|         cmpBitpack<T><<<blocksPerGrid, threadsPerBlock >>>(input.specialBuffer(), output.specialBuffer(), rank, output.lengthOf(), extendedStridesDevPtr, output.specialShapeInfo(), threshold); | ||||
|     } | ||||
| 
 | ||||
|     NDArray::registerSpecialUse({&output}, {&input}); | ||||
|     manager.synchronize();  | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| void compareAndBitpack(sd::graph::Context& block, const NDArray& input, const NDArray& threshold, NDArray& output)  { | ||||
| 
 | ||||
|     BUILD_SINGLE_SELECTOR(input.dataType(), cmpBitpackCudaLauncher, (block, input, threshold, output), LIBND4J_TYPES); | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| } | ||||
| } | ||||
| 
 | ||||
| @ -26,7 +26,7 @@ | ||||
| #include <ops/declarable/helpers/helpers.h> | ||||
| #include <helpers/helper_random.h> | ||||
| #include <graph/RandomGenerator.h> | ||||
| 
 | ||||
| #include <graph/Context.h> | ||||
| namespace sd    { | ||||
| namespace ops     { | ||||
| namespace helpers { | ||||
| @ -84,6 +84,8 @@ namespace helpers { | ||||
| 	void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps); | ||||
| 
 | ||||
| 	void split(sd::LaunchContext* context, const NDArray& input, std::vector<NDArray*>& outArrs, const int axis); | ||||
| 
 | ||||
| 	void compareAndBitpack(graph::Context& block, const NDArray& input, const NDArray& threshold, NDArray& output); | ||||
| } | ||||
| } | ||||
| } | ||||
|  | ||||
| @ -27,7 +27,7 @@ | ||||
| #include <ops/ops.h> | ||||
| #include <helpers/GradCheck.h> | ||||
| #include <loops/random.h> | ||||
| 
 | ||||
| #include <array/DataType.h> | ||||
| 
 | ||||
| using namespace sd; | ||||
| 
 | ||||
| @ -1757,6 +1757,208 @@ TEST_F(DeclarableOpsTests9, prelu_test14) { | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) { | ||||
| 
 | ||||
|     auto x = NDArrayFactory::create<float>('c', {2, 3, 16}, { | ||||
|     0.865595f, 0.381197f, 0.911656f, 0.256752f, 0.084921f, 0.070434f, 0.469923f, 0.269935f, 0.510656f, 0.949777f, 0.926772f, 0.622540f, 0.688253f, 0.164974f, | ||||
|     0.068558f, 0.031173f, 0.910035f, 0.219362f, 0.731336f, 0.135392f, 0.449875f, 0.020135f, 0.891820f, 0.907567f, 0.114376f, 0.652253f, 0.892939f, 0.698095f, | ||||
|     0.423831f, 0.971155f, 0.968733f, 0.194465f, 0.852475f, 0.642962f, 0.417665f, 0.768379f, 0.753035f, 0.738440f, 0.046251f, 0.659487f, 0.486230f, 0.246724f, | ||||
|     0.276700f, 0.103631f, 0.843105f, 0.562587f, 0.784459f, 0.109871f, 0.455828f, 0.129641f, 0.002471f, 0.148281f, 0.976162f, 0.603573f, 0.752530f, 0.249840f, | ||||
|     0.723716f, 0.658430f, 0.661057f, 0.328042f, 0.338351f, 0.903157f, 0.485580f, 0.405103f, 0.335052f, 0.509858f, 0.764852f, 0.764527f, 0.382572f, 0.962121f, | ||||
|     0.296145f, 0.602766f, 0.169683f, 0.750371f, 0.993936f, 0.914704f, 0.199342f, 0.858098f, 0.617198f, 0.219334f, 0.167574f, 0.305204f, 0.960773f, 0.537944f, | ||||
|     0.245441f, 0.787276f, 0.968920f, 0.980918f, 0.615237f, 0.355165f, 0.480441f, 0.304282f, 0.961229f, 0.639195f, 0.017776f, 0.836153f | ||||
|     }); | ||||
|     auto threshold = NDArrayFactory::create<float>(0.5f); | ||||
|     auto exp = NDArrayFactory::create<uint8_t>('c', {2, 3, 2}, {160, 248, 163, 118, 221, 14, 14, 228, 117, 118, 55, 141}); | ||||
| 
 | ||||
|     sd::ops::compare_and_bitpack op; | ||||
|     auto result = op.evaluate({&x, &threshold}, {}, {}, {}); | ||||
|     auto output = result.at(0); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, result.status()); | ||||
|     ASSERT_TRUE(exp.isSameShape(output)); | ||||
|     ASSERT_TRUE(exp.equalsTo(output)); | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| TEST_F(DeclarableOpsTests9, compare_and_bitpack_test2) { | ||||
| 
 | ||||
|     auto x = NDArrayFactory::create<bool>('c', {2, 3, 16}, { | ||||
|         true, false, true, false, false, false, false, false, true, | ||||
|         true, true, true, true, false, false, false, true, false, | ||||
|         true, false, false, false, true, true, false, true, true, | ||||
|         true, false, true, true, false, true, true, false, true, | ||||
|         true, true, false, true, false, false, false, false, true, | ||||
|         true, true, false, false, false, false, false, true, true, | ||||
|         true, false, true, true, true, false, false, true, false, | ||||
|         false, false, true, true, true, false, true, false, true, | ||||
|         false, true, true, true, false, true, true, false, false, | ||||
|         false, true, true, false, true, true, true, true, false, | ||||
|         false, false, true, true, false, true | ||||
|     }); | ||||
|     //threshold is ignored here ,actually
 | ||||
|     auto threshold = NDArrayFactory::create<bool>(true); | ||||
|     auto exp = NDArrayFactory::create<uint8_t>('c', {2, 3, 2}, {160, 248, 163, 118, 221, 14, 14, 228, 117, 118, 55, 141}); | ||||
| 
 | ||||
|     sd::ops::compare_and_bitpack op; | ||||
|     auto result = op.evaluate({&x, &threshold}, {}, {}, {}); | ||||
|     auto output = result.at(0); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, result.status()); | ||||
|     ASSERT_TRUE(exp.isSameShape(output)); | ||||
|     ASSERT_TRUE(exp.equalsTo(output)); | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests9, compare_and_bitpack_test3) { | ||||
| 
 | ||||
|     auto x = NDArrayFactory::create<float>('c', {2, 0, 3, 16}); | ||||
|     auto threshold = NDArrayFactory::create<float>(0.5f); | ||||
|     auto exp = NDArrayFactory::create<uint8_t>('c', {2, 0, 3, 2}); | ||||
| 
 | ||||
|     sd::ops::compare_and_bitpack op; | ||||
|     auto result = op.evaluate({&x, &threshold}, {}, {}, {}); | ||||
|     auto output = result.at(0); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, result.status()); | ||||
|     output->printShapeInfo("output"); | ||||
|     ASSERT_TRUE(exp.isSameShape(output)); | ||||
|     ASSERT_TRUE(exp.equalsTo(output)); | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests9, compare_and_bitpack_test4) { | ||||
| 
 | ||||
|     auto x = NDArrayFactory::create<float>('c', {2, 0, 3, 13}); | ||||
|     auto threshold = NDArrayFactory::create<float>(0.5f); | ||||
|     sd::ops::compare_and_bitpack op;  | ||||
| 
 | ||||
|     ASSERT_THROW(op.evaluate({&x, &threshold}, {}, {}, {}), std::invalid_argument);  | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests9, compare_and_bitpack_test5) { | ||||
| 
 | ||||
|     auto x = NDArrayFactory::create<float>('c', {2, 0, 3, 13}); | ||||
|     auto threshold = NDArrayFactory::create<float>(0.5f); | ||||
|     auto out =  NDArrayFactory::create<uint8_t>('c', {2, 0, 3, 1}); | ||||
|     sd::ops::compare_and_bitpack op;  | ||||
| 
 | ||||
|     ASSERT_THROW(op.execute({&x, &threshold}, {&out}, {}, {}), std::invalid_argument);  | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests9, compare_and_bitpack_test6) { | ||||
| 
 | ||||
|     auto x = NDArrayFactory::create<float>('c', {2, 0, 3, 8}); | ||||
|     auto threshold = NDArrayFactory::create<float>(0.5f); | ||||
|     auto out =  NDArrayFactory::create<uint8_t>('c', {2, 0, 3, 2}); | ||||
|     sd::ops::compare_and_bitpack op;  | ||||
|     //shape mismatch throws runtime error
 | ||||
|     ASSERT_THROW(op.execute({&x, &threshold}, {&out}, {}, {}), std::runtime_error);  | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| TEST_F(DeclarableOpsTests9, compare_and_bitpack_test7) { | ||||
|     constexpr int pp = 32*32*16; | ||||
|     constexpr int s1 = 3;  | ||||
|     constexpr int t1 = 8; | ||||
|     std::vector<Nd4jLong> shape1 = {pp};  | ||||
|     std::vector<Nd4jLong> strides1 = {s1}; | ||||
|     std::vector<Nd4jLong> shape2 = {pp/8};  | ||||
|     std::vector<Nd4jLong> strides2 = {t1}; | ||||
|     ShapeDescriptor desc1 (DataType::BOOL, 'c', shape1, strides1, s1); | ||||
|     ShapeDescriptor desc2 (DataType::UINT8, 'c', shape2, strides2, t1); | ||||
|     auto x = NDArrayFactory::create(desc1); | ||||
|     auto output = NDArrayFactory::create(desc2); | ||||
|     auto exp =  NDArrayFactory::create(desc2); | ||||
|     auto threshold = NDArrayFactory::create<bool>(true); | ||||
|     auto buff = x.bufferAsT<bool>(); | ||||
| 	uint8_t *expBuff = exp.bufferAsT<uint8_t>(); | ||||
|     //generate test
 | ||||
|     for(int l=0;l<pp; l+=8){ | ||||
|                 uint8_t test =  rand() % 255; | ||||
|                 expBuff[l/8*t1] = test; | ||||
|                 auto buffP = &(buff[l*s1]); | ||||
|                 buffP[0] = test & (1<<7); | ||||
|                 buffP[1*s1] = test & (1<<6); | ||||
|                 buffP[2*s1] = test & (1<<5); | ||||
|                 buffP[3*s1] = test & (1<<4); | ||||
|                 buffP[4*s1] = test & (1<<3); | ||||
|                 buffP[5*s1] = test & (1<<2); | ||||
|                 buffP[6*s1] = test & (1<<1); | ||||
|                 buffP[7*s1] = test & 1; | ||||
|     } | ||||
|     //explicit sync to device
 | ||||
|     x.tickWriteHost(); | ||||
|     exp.tickWriteHost(); | ||||
|     x.syncToDevice(); | ||||
|     exp.syncToDevice(); | ||||
| 
 | ||||
|     sd::ops::compare_and_bitpack op; | ||||
|     auto result = op.execute({&x, &threshold}, {&output}, {}, {}); | ||||
|     ASSERT_EQ(Status::OK(), result); | ||||
|     ASSERT_TRUE(exp.isSameShape(&output)); | ||||
|     ASSERT_TRUE(exp.equalsTo(&output)); | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| TEST_F(DeclarableOpsTests9, compare_and_bitpack_test8) { | ||||
|     constexpr int pp = 32; | ||||
|     constexpr int s1 = 2; | ||||
|     constexpr int s2 = (s1*pp) + 3; | ||||
|     constexpr int s3 = (s2*pp) + 4; | ||||
|     constexpr int t1 = 2; | ||||
|     constexpr int t2 = (t1*pp/8) + 3; | ||||
|     constexpr int t3 = (t2*pp) + 4; | ||||
|     std::vector<Nd4jLong> shape1 = {pp,pp,pp};  | ||||
|     std::vector<Nd4jLong> strides1 = {s3 , s2 , s1}; | ||||
|     std::vector<Nd4jLong> shape2 = {pp,pp,pp/8};  | ||||
|     std::vector<Nd4jLong> strides2 = {t3 , t2 , t1}; | ||||
|     ShapeDescriptor desc1 (DataType::BOOL, 'c', shape1, strides1, 0); | ||||
|     ShapeDescriptor desc2 (DataType::UINT8, 'c', shape2, strides2, 0); | ||||
|     auto x = NDArrayFactory::create(desc1); | ||||
|     auto output =  NDArrayFactory::create(desc2); | ||||
|     auto exp =  NDArrayFactory::create(desc2); | ||||
|     auto threshold = NDArrayFactory::create<bool>(true); | ||||
|     auto buff = x.bufferAsT<bool>(); | ||||
| 	uint8_t *expBuff = exp.bufferAsT<uint8_t>(); | ||||
|     //generate test
 | ||||
|     for(int i=0;i<pp;i++){ | ||||
|         for(int j=0;j<pp;j++){ | ||||
|             for(int l=0;l<pp; l+=8){ | ||||
|                 uint8_t test =  rand() % 255; | ||||
|                 expBuff[l/8*t1 + j*t2 + i *t3] = test; | ||||
|                 auto buffP = &(buff[j*s2 + i *s3 + l*s1]); | ||||
|                 buffP[0] = test & (1<<7); | ||||
|                 buffP[1*s1] = test & (1<<6); | ||||
|                 buffP[2*s1] = test & (1<<5); | ||||
|                 buffP[3*s1] = test & (1<<4); | ||||
|                 buffP[4*s1] = test & (1<<3); | ||||
|                 buffP[5*s1] = test & (1<<2); | ||||
|                 buffP[6*s1] = test & (1<<1); | ||||
|                 buffP[7*s1] = test & 1; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     //explicit sync to device
 | ||||
|     x.tickWriteHost(); | ||||
|     exp.tickWriteHost(); | ||||
|     x.syncToDevice(); | ||||
|     exp.syncToDevice(); | ||||
|     sd::ops::compare_and_bitpack op; | ||||
|     auto result = op.execute({&x, &threshold}, {&output}, {}, {}); | ||||
|     ASSERT_EQ(Status::OK(), result); | ||||
|     ASSERT_TRUE(exp.isSameShape(&output)); | ||||
|     ASSERT_TRUE(exp.equalsTo(&output)); | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) { | ||||
| 
 | ||||
| @ -1776,25 +1978,6 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) { | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) { | ||||
| 
 | ||||
|     auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); | ||||
|     auto threshold = NDArrayFactory::create<double>(2.0); | ||||
|     auto exp = NDArrayFactory::create<uint8_t>('c', {2, 3, 4}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | ||||
|                                                                                 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1}); | ||||
| 
 | ||||
|     sd::ops::compare_and_bitpack op; | ||||
| 
 | ||||
|     auto result = op.evaluate({&x, &threshold}, {}, {}, {}); | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, result.status()); | ||||
|     auto output = result.at(0); | ||||
| //    output->printIndexedBuffer("Packed to uint8");
 | ||||
|     ASSERT_TRUE(exp.isSameShape(output)); | ||||
|     ASSERT_TRUE(exp.equalsTo(output)); | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) { | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user