Merge pull request #2 from KonduitAI/shugeo_bincast

[WIP] Shugeo bitcast
master
raver119 2019-10-03 14:27:24 +03:00 committed by GitHub
commit d2e98564d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 219 additions and 0 deletions

View File

@ -0,0 +1,92 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author George A. Shulinok <sgazeos@gmail.com>
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_bitcast)
#include <array/DataTypeUtils.h>
#include <ops/declarable/CustomOperations.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(bitcast, 1, 1, false, 0, 1) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
// when empty - nothing to do
if(input->isEmpty()){
REQUIRE_TRUE(output->isEmpty(), 0, "BITCAST: If input is empty, output array must also be empty.");
return Status::OK();
}
// buffers for both input and output should be equals
DataBuffer buf(input->buffer(), input->specialBuffer(), input->lengthOf() * input->sizeOfT(), input->dataType());
*(output->dataBuffer()) = buf;
return Status::OK();
}
DECLARE_SYN(BitCast, bitcast);
DECLARE_SHAPE_FN(bitcast) {
auto inShape = inputShape->at(0);
auto inputRank = shape::rank(inShape);
auto it = INT_ARG(0);
DataType newType = DataTypeUtils::fromInt(it);
DataType oldType = ArrayOptions::dataType(inShape);
// correct output shape to conform with output data type
auto inputSize = DataTypeUtils::sizeOf(oldType);
auto outputSize = DataTypeUtils::sizeOf(newType);
if (shape::length(inShape) == 0)
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType)));
if (inputSize == outputSize) {
// only type should be changed
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType)));
}
else if (inputSize > outputSize) {
// range of output increased by 1 with inputSize / outputSize as last dimension
std::vector<Nd4jLong> shapeOf(inputRank + 1);
int i;
for (i = 0; i < inputRank; ++i) {
shapeOf[i] = inShape[i + 1];
}
shapeOf[i] = inputSize / outputSize;
auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo(newType, shape::order(inShape), shapeOf);
return SHAPELIST(outputShape);
}
REQUIRE_TRUE(shape::sizeAt(inShape, -1) == outputSize / inputSize, 0, "BITCAST: %ull > %ull. So last dimension should be %ull, but %i given.", inputSize, outputSize, outputSize / inputSize, shape::sizeAt(inShape, -1));
std::vector<Nd4jLong> shapeOf(inputRank - 1);
for (auto i = 0; i < shapeOf.size(); ++i) {
shapeOf[i] = inShape[i + 1];
}
auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo(newType, shape::order(inShape), shapeOf);
return SHAPELIST(outputShape);
}
DECLARE_TYPES(bitcast) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes(nd4j::DataType::ANY);
}
}
}
#endif

View File

@ -0,0 +1,60 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author sgazeos@gmail.com
//
#include <ops/declarable/generic/helpers/BroadcastHelper.h>
#include <ops/declarable/headers/parity_ops.h>
#include <ops/declarable/headers/datatypes.h>
#include <NDArrayFactory.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(compare_and_bitpack, 2, 1, false, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
auto z0 = NDArrayFactory::create<bool>(x->ordering(), x->getShapeAsVector());
BROADCAST_CHECK_EMPTY(x, y, (&z0));
auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0);
bitcast res;
auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, false);
if (tZ != &z0) {
delete tZ;
}
return status;
}
DECLARE_TYPES(compare_and_bitpack) {
getOpDescriptor()
->setAllowedInputTypes(0, DataType::ANY)
->setAllowedInputTypes(1, DataType::ANY)
->setAllowedOutputTypes(0, DataType::UINT8);
}
DECLARE_SHAPE_FN(compare_and_bitpack) {
auto inShape = inputShape->at(0);
DataType newType = DataType::UINT8;
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType)));
}
}
}

View File

@ -99,6 +99,14 @@ namespace nd4j {
#if NOT_EXCLUDED(OP_cast)
DECLARE_CUSTOM_OP(cast, 1, 1, false, 0, 1);
#endif
/**
* This operation change type of input and modified shape of output to conform with given data type
*
* all as above op
* */
#if NOT_EXCLUDED(OP_bitcast)
DECLARE_CUSTOM_OP(bitcast, 1, 1, false, 0, 1);
#endif
}
}

View File

@ -1731,6 +1731,20 @@ namespace nd4j {
DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars, 3, 1, true, 0, -2);
#endif
/**
* compare_and_bitpack - compare with greater and pack result with uint8
*
* input params:
* 0 - NDArray (input)
* 1 - 0D Tensor - threshold
*
*
* output:
* 0 - NDArray with the same shape as input and type uint8
*/
#if NOT_EXCLUDED(OP_compare_and_bitpack)
DECLARE_CUSTOM_OP(compare_and_bitpack, 2, 1, false, 0, 0);
#endif
}
}

View File

@ -228,6 +228,32 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) {
ASSERT_TRUE(e.equalsTo(out));
delete result;
}
TEST_F(DeclarableOpsTests15, Test_BitCast_1) {
auto x = NDArrayFactory::create<float>('c', {2, 2, 2});
auto e = NDArrayFactory::create<double>('c', {2, 2}, {2., 512., 8192., 131072.032 });
x.linspace(1.);
nd4j::ops::bitcast op;
auto result = op.execute({&x}, {}, {nd4j::DataType::DOUBLE}, {});
ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0);
// out->printIndexedBuffer("Casted result");
ASSERT_TRUE(e.equalsTo(out));
delete result;
}
TEST_F(DeclarableOpsTests15, Test_BitCast_2) {
auto x = NDArrayFactory::create<float>('c', {2, 4});
auto e = NDArrayFactory::create<float16>('c', {2, 4, 2}, {0, 1.875, 0, 2., 0, 2.125, 0, 2.25,
0, 2.312, 0, 2.375, 0, 2.438, 0., 2.5});
x.linspace(1.);
nd4j::ops::bitcast op;
auto result = op.execute({&x}, {}, {nd4j::DataType::HALF}, {});
ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0);
out->printIndexedBuffer("Casted result");
ASSERT_TRUE(e.equalsTo(out));
delete result;
}
TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) {
auto in = NDArrayFactory::create<float>('c', {4, 8, 64, 64});

View File

@ -2387,6 +2387,25 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) {
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
auto threshold = NDArrayFactory::create<double>(2.0);
auto exp = NDArrayFactory::create<uint8_t>('c', {2, 3, 4}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1});
nd4j::ops::compare_and_bitpack op;
auto result = op.execute({&x, &threshold}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0);
// output->printIndexedBuffer("Packed to uint8");
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {