Implemented compare_and_bitpack op.
This commit is contained in:
		
							parent
							
								
									75ad3c8153
								
							
						
					
					
						commit
						130ee25682
					
				@ -0,0 +1,60 @@
 | 
				
			|||||||
 | 
					/*******************************************************************************
 | 
				
			||||||
 | 
					 * Copyright (c) 2015-2018 Skymind, Inc.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * This program and the accompanying materials are made available under the
 | 
				
			||||||
 | 
					 * terms of the Apache License, Version 2.0 which is available at
 | 
				
			||||||
 | 
					 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
				
			||||||
 | 
					 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
				
			||||||
 | 
					 * License for the specific language governing permissions and limitations
 | 
				
			||||||
 | 
					 * under the License.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * SPDX-License-Identifier: Apache-2.0
 | 
				
			||||||
 | 
					 ******************************************************************************/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// @author sgazeos@gmail.com
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <ops/declarable/generic/helpers/BroadcastHelper.h>
 | 
				
			||||||
 | 
					#include <ops/declarable/headers/parity_ops.h>
 | 
				
			||||||
 | 
					#include <ops/declarable/headers/datatypes.h>
 | 
				
			||||||
 | 
					#include <NDArrayFactory.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace nd4j {
 | 
				
			||||||
 | 
					    namespace ops {
 | 
				
			||||||
 | 
					        CUSTOM_OP_IMPL(compare_and_bitpack, 2, 1, false, 0, 0) {
 | 
				
			||||||
 | 
					            auto x = INPUT_VARIABLE(0);
 | 
				
			||||||
 | 
					            auto y = INPUT_VARIABLE(1);
 | 
				
			||||||
 | 
					            auto z = OUTPUT_VARIABLE(0);
 | 
				
			||||||
 | 
					            auto z0 = NDArrayFactory::create<bool>(x->ordering(), x->getShapeAsVector());
 | 
				
			||||||
 | 
					            BROADCAST_CHECK_EMPTY(x, y, (&z0));
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					            auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0);
 | 
				
			||||||
 | 
					            bitcast res;
 | 
				
			||||||
 | 
					            auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, false);
 | 
				
			||||||
 | 
					            if (tZ != &z0) {
 | 
				
			||||||
 | 
					                delete tZ;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					            return status;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        DECLARE_TYPES(compare_and_bitpack) {
 | 
				
			||||||
 | 
					            getOpDescriptor()
 | 
				
			||||||
 | 
					                    ->setAllowedInputTypes(0, DataType::ANY)
 | 
				
			||||||
 | 
					                    ->setAllowedInputTypes(1, DataType::ANY)
 | 
				
			||||||
 | 
					                    ->setAllowedOutputTypes(0, DataType::UINT8);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        DECLARE_SHAPE_FN(compare_and_bitpack) {
 | 
				
			||||||
 | 
					            auto inShape = inputShape->at(0);
 | 
				
			||||||
 | 
					            DataType newType = DataType::UINT8;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType)));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -1731,6 +1731,20 @@ namespace nd4j {
 | 
				
			|||||||
        DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars, 3, 1, true, 0, -2);
 | 
					        DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars, 3, 1, true, 0, -2);
 | 
				
			||||||
        #endif
 | 
					        #endif
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        /**
 | 
				
			||||||
 | 
					         * compare_and_bitpack - compare with greater and pack result with uint8 
 | 
				
			||||||
 | 
					         *
 | 
				
			||||||
 | 
					         * input params:
 | 
				
			||||||
 | 
					         *    0 - NDArray (input)
 | 
				
			||||||
 | 
					         *    1 - 0D Tensor - threshold
 | 
				
			||||||
 | 
					         *
 | 
				
			||||||
 | 
					         *
 | 
				
			||||||
 | 
					         * output:
 | 
				
			||||||
 | 
					         *    0 - NDArray with the same shape as input and type uint8
 | 
				
			||||||
 | 
					         */
 | 
				
			||||||
 | 
					        #if NOT_EXCLUDED(OP_compare_and_bitpack)
 | 
				
			||||||
 | 
					        DECLARE_CUSTOM_OP(compare_and_bitpack, 2, 1, false, 0, 0);
 | 
				
			||||||
 | 
					        #endif
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -2387,6 +2387,25 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) {
 | 
				
			|||||||
    delete result;
 | 
					    delete result;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					////////////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
 | 
					TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
 | 
				
			||||||
 | 
					    auto threshold = NDArrayFactory::create<double>(2.0);
 | 
				
			||||||
 | 
					    auto exp = NDArrayFactory::create<uint8_t>('c', {2, 3, 4}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
 | 
				
			||||||
 | 
					                                                                                0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    nd4j::ops::compare_and_bitpack op;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto result = op.execute({&x, &threshold}, {}, {}, {});
 | 
				
			||||||
 | 
					    ASSERT_EQ(ND4J_STATUS_OK, result->status());
 | 
				
			||||||
 | 
					    auto output = result->at(0);
 | 
				
			||||||
 | 
					//    output->printIndexedBuffer("Packed to uint8");
 | 
				
			||||||
 | 
					    ASSERT_TRUE(exp.isSameShape(output));
 | 
				
			||||||
 | 
					    ASSERT_TRUE(exp.equalsTo(output));
 | 
				
			||||||
 | 
					    delete result;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
////////////////////////////////////////////////////////////////////////////////
 | 
					////////////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {
 | 
					TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user