Merge pull request #2 from KonduitAI/shugeo_bincast
[WIP] Shugeo bitcast
This commit is contained in:
		
						commit
						d2e98564d4
					
				
							
								
								
									
										92
									
								
								libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										92
									
								
								libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,92 @@ | ||||
| /*******************************************************************************
 | ||||
|  * Copyright (c) 2015-2018 Skymind, Inc. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0.
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //
 | ||||
| // @author George A. Shulinok <sgazeos@gmail.com>
 | ||||
| //
 | ||||
| 
 | ||||
| #include <op_boilerplate.h> | ||||
| #if NOT_EXCLUDED(OP_bitcast) | ||||
| 
 | ||||
| #include <array/DataTypeUtils.h> | ||||
| #include <ops/declarable/CustomOperations.h> | ||||
| 
 | ||||
| namespace nd4j { | ||||
|     namespace ops { | ||||
|         CUSTOM_OP_IMPL(bitcast, 1, 1, false, 0, 1) { | ||||
|             auto input = INPUT_VARIABLE(0); | ||||
|             auto output = OUTPUT_VARIABLE(0); | ||||
|             // when empty - nothing to do
 | ||||
|             if(input->isEmpty()){ | ||||
|                 REQUIRE_TRUE(output->isEmpty(), 0, "BITCAST: If input is empty, output array must also be empty."); | ||||
|                 return Status::OK(); | ||||
|             } | ||||
|             // buffers for both input and output should be equals
 | ||||
|             DataBuffer buf(input->buffer(), input->specialBuffer(), input->lengthOf() * input->sizeOfT(), input->dataType()); | ||||
|             *(output->dataBuffer()) = buf; | ||||
| 
 | ||||
|             return Status::OK(); | ||||
|         } | ||||
|         DECLARE_SYN(BitCast, bitcast); | ||||
| 
 | ||||
|         DECLARE_SHAPE_FN(bitcast) { | ||||
|             auto inShape = inputShape->at(0); | ||||
|             auto inputRank = shape::rank(inShape); | ||||
|             auto it = INT_ARG(0); | ||||
|             DataType newType = DataTypeUtils::fromInt(it); | ||||
|             DataType oldType = ArrayOptions::dataType(inShape); | ||||
|             // correct output shape to conform with output data type
 | ||||
|             auto inputSize = DataTypeUtils::sizeOf(oldType); | ||||
|             auto outputSize = DataTypeUtils::sizeOf(newType); | ||||
| 
 | ||||
|             if (shape::length(inShape) == 0) | ||||
|                 return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType))); | ||||
| 
 | ||||
|             if (inputSize == outputSize) { | ||||
|                 // only type should be changed
 | ||||
|                 return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType))); | ||||
|             } | ||||
|             else if (inputSize > outputSize) { | ||||
|                 // range of output increased by 1 with inputSize / outputSize as last dimension
 | ||||
|                 std::vector<Nd4jLong> shapeOf(inputRank + 1); | ||||
|                 int i; | ||||
|                 for (i = 0; i < inputRank; ++i) { | ||||
|                     shapeOf[i] = inShape[i + 1]; | ||||
|                 } | ||||
|                 shapeOf[i] = inputSize / outputSize; | ||||
|                 auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo(newType, shape::order(inShape), shapeOf); | ||||
|                 return SHAPELIST(outputShape); | ||||
|             } | ||||
|             REQUIRE_TRUE(shape::sizeAt(inShape, -1) == outputSize / inputSize, 0, "BITCAST: %ull > %ull. So last dimension should be %ull, but %i given.", inputSize, outputSize, outputSize / inputSize, shape::sizeAt(inShape, -1)); | ||||
|             std::vector<Nd4jLong> shapeOf(inputRank - 1); | ||||
| 
 | ||||
|             for (auto i = 0; i < shapeOf.size(); ++i) { | ||||
|                 shapeOf[i] = inShape[i + 1]; | ||||
|             } | ||||
| 
 | ||||
|             auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo(newType, shape::order(inShape), shapeOf); | ||||
|             return SHAPELIST(outputShape); | ||||
|         } | ||||
| 
 | ||||
|         DECLARE_TYPES(bitcast) { | ||||
|             getOpDescriptor() | ||||
|                     ->setAllowedInputTypes(nd4j::DataType::ANY) | ||||
|                     ->setAllowedOutputTypes(nd4j::DataType::ANY); | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #endif | ||||
| @ -0,0 +1,60 @@ | ||||
| /*******************************************************************************
 | ||||
|  * Copyright (c) 2015-2018 Skymind, Inc. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0.
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //
 | ||||
| // @author sgazeos@gmail.com
 | ||||
| //
 | ||||
| 
 | ||||
| #include <ops/declarable/generic/helpers/BroadcastHelper.h> | ||||
| #include <ops/declarable/headers/parity_ops.h> | ||||
| #include <ops/declarable/headers/datatypes.h> | ||||
| #include <NDArrayFactory.h> | ||||
| 
 | ||||
| namespace nd4j { | ||||
|     namespace ops { | ||||
|         CUSTOM_OP_IMPL(compare_and_bitpack, 2, 1, false, 0, 0) { | ||||
|             auto x = INPUT_VARIABLE(0); | ||||
|             auto y = INPUT_VARIABLE(1); | ||||
|             auto z = OUTPUT_VARIABLE(0); | ||||
|             auto z0 = NDArrayFactory::create<bool>(x->ordering(), x->getShapeAsVector()); | ||||
|             BROADCAST_CHECK_EMPTY(x, y, (&z0)); | ||||
|              | ||||
|             auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0); | ||||
|             bitcast res; | ||||
|             auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, false); | ||||
|             if (tZ != &z0) { | ||||
|                 delete tZ; | ||||
|             } | ||||
|              | ||||
|             return status; | ||||
|         } | ||||
| 
 | ||||
|         DECLARE_TYPES(compare_and_bitpack) { | ||||
|             getOpDescriptor() | ||||
|                     ->setAllowedInputTypes(0, DataType::ANY) | ||||
|                     ->setAllowedInputTypes(1, DataType::ANY) | ||||
|                     ->setAllowedOutputTypes(0, DataType::UINT8); | ||||
|         } | ||||
| 
 | ||||
|         DECLARE_SHAPE_FN(compare_and_bitpack) { | ||||
|             auto inShape = inputShape->at(0); | ||||
|             DataType newType = DataType::UINT8; | ||||
| 
 | ||||
|             return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType))); | ||||
|         } | ||||
| 
 | ||||
|     } | ||||
| } | ||||
| @ -99,6 +99,14 @@ namespace nd4j { | ||||
|         #if NOT_EXCLUDED(OP_cast) | ||||
|         DECLARE_CUSTOM_OP(cast, 1, 1, false, 0, 1); | ||||
|         #endif | ||||
|         /**
 | ||||
|          * This operation change type of input and modified shape of output to conform with given data type | ||||
|          * | ||||
|          * all as above op | ||||
|          * */ | ||||
|         #if NOT_EXCLUDED(OP_bitcast) | ||||
|                 DECLARE_CUSTOM_OP(bitcast, 1, 1, false, 0, 1); | ||||
|         #endif | ||||
|     } | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -1731,6 +1731,20 @@ namespace nd4j { | ||||
|         DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars, 3, 1, true, 0, -2); | ||||
|         #endif | ||||
| 
 | ||||
|         /**
 | ||||
|          * compare_and_bitpack - compare with greater and pack result with uint8  | ||||
|          * | ||||
|          * input params: | ||||
|          *    0 - NDArray (input) | ||||
|          *    1 - 0D Tensor - threshold | ||||
|          * | ||||
|          * | ||||
|          * output: | ||||
|          *    0 - NDArray with the same shape as input and type uint8 | ||||
|          */ | ||||
|         #if NOT_EXCLUDED(OP_compare_and_bitpack) | ||||
|         DECLARE_CUSTOM_OP(compare_and_bitpack, 2, 1, false, 0, 0); | ||||
|         #endif | ||||
|     } | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -228,6 +228,32 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) { | ||||
|     ASSERT_TRUE(e.equalsTo(out)); | ||||
|     delete result; | ||||
| } | ||||
| TEST_F(DeclarableOpsTests15, Test_BitCast_1) { | ||||
|     auto x = NDArrayFactory::create<float>('c', {2, 2, 2}); | ||||
|     auto e = NDArrayFactory::create<double>('c', {2, 2}, {2., 512., 8192., 131072.032 }); | ||||
|     x.linspace(1.); | ||||
|     nd4j::ops::bitcast op; | ||||
|     auto result = op.execute({&x}, {}, {nd4j::DataType::DOUBLE}, {}); | ||||
|     ASSERT_EQ(Status::OK(), result->status()); | ||||
|     auto out = result->at(0); | ||||
| //    out->printIndexedBuffer("Casted result");
 | ||||
|     ASSERT_TRUE(e.equalsTo(out)); | ||||
|     delete result; | ||||
| } | ||||
| 
 | ||||
| TEST_F(DeclarableOpsTests15, Test_BitCast_2) { | ||||
|     auto x = NDArrayFactory::create<float>('c', {2, 4}); | ||||
|     auto e = NDArrayFactory::create<float16>('c', {2, 4, 2}, {0, 1.875, 0, 2.,    0, 2.125, 0,  2.25, | ||||
|                                                                               0, 2.312, 0, 2.375, 0, 2.438, 0., 2.5}); | ||||
|     x.linspace(1.); | ||||
|     nd4j::ops::bitcast op; | ||||
|     auto result = op.execute({&x}, {}, {nd4j::DataType::HALF}, {}); | ||||
|     ASSERT_EQ(Status::OK(), result->status()); | ||||
|     auto out = result->at(0); | ||||
|     out->printIndexedBuffer("Casted result"); | ||||
|     ASSERT_TRUE(e.equalsTo(out)); | ||||
|     delete result; | ||||
| } | ||||
| 
 | ||||
| TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) { | ||||
|     auto in = NDArrayFactory::create<float>('c', {4, 8, 64, 64}); | ||||
|  | ||||
| @ -2387,6 +2387,25 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) { | ||||
|     delete result; | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) { | ||||
| 
 | ||||
|     auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); | ||||
|     auto threshold = NDArrayFactory::create<double>(2.0); | ||||
|     auto exp = NDArrayFactory::create<uint8_t>('c', {2, 3, 4}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | ||||
|                                                                                 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1}); | ||||
| 
 | ||||
|     nd4j::ops::compare_and_bitpack op; | ||||
| 
 | ||||
|     auto result = op.execute({&x, &threshold}, {}, {}, {}); | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, result->status()); | ||||
|     auto output = result->at(0); | ||||
| //    output->printIndexedBuffer("Packed to uint8");
 | ||||
|     ASSERT_TRUE(exp.isSameShape(output)); | ||||
|     ASSERT_TRUE(exp.equalsTo(output)); | ||||
|     delete result; | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) { | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user