- rewrite broadcast_dynamic_shape and delete corresponding helpers (#194)

Signed-off-by: Yurii <yurii@skymind.io>
master
Yurii Shyrma 2019-08-29 20:38:02 +03:00 committed by raver119
parent 0463ee4eba
commit 5395d4fbe5
5 changed files with 73 additions and 330 deletions

View File

@ -15,54 +15,79 @@
******************************************************************************/
//
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_broadcast_dynamic_shape)
//#include <ops/declarable/headers/parity_ops.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/bds.h>
namespace nd4j {
namespace ops {
DECLARE_TYPES(broadcast_dynamic_shape) {
getOpDescriptor()
->setAllowedOutputTypes({ALL_INTS})
->setAllowedInputTypes({ALL_INTS})
->setSameMode(true);
}
namespace ops {
CUSTOM_OP_IMPL(broadcast_dynamic_shape, 2, 1, false, 0, 0) {
auto x_shape = INPUT_VARIABLE(0);
auto y_shape = INPUT_VARIABLE(1);
REQUIRE_TRUE(shape::isVector(x_shape->shapeInfo(), 1), 0, "broadcast_dynamic_shape: The first argument should be a vector");
REQUIRE_TRUE(shape::isVector(y_shape->shapeInfo(), 1), 0, "broadcast_dynamic_shape: The second argument should be a vector");
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(broadcast_dynamic_shape, 2, 1, false, 0, 0) {
auto output = OUTPUT_VARIABLE(0);
return helpers::bdsFunctor(block.launchContext(), x_shape, y_shape, output);
}
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
DECLARE_SHAPE_FN(broadcast_dynamic_shape) {
auto shapeList = SHAPELIST();
auto theFirst = inputShape->at(0);
auto theSecond = inputShape->at(1);
auto z = OUTPUT_VARIABLE(0);
auto theFirstLen = shape::sizeAt(theFirst, -1);
auto theSecondLen = shape::sizeAt(theSecond, -1);
REQUIRE_TRUE(x->rankOf() == 1, 0, "BROADCAST_DYNAMIC_SHAPE OP: the first input array must have rank = 1, but got %i instead!", x->rankOf());
REQUIRE_TRUE(y->rankOf() == 1, 0, "BROADCAST_DYNAMIC_SHAPE OP: the second input array must have rank = 1, but got %i instead!", y->rankOf());
REQUIRE_TRUE(x->dataType() == y->dataType(), 0, "BROADCAST_DYNAMIC_SHAPE OP: both input arrays must have the same integer type !");
auto shapeLength = nd4j::math::nd4j_max(theFirstLen, theSecondLen);
// contract shapeInfos, neglect and don't fill strides, ews, order
// shapes are of interest only
std::vector<Nd4jLong> xShapeInfo(shape::shapeInfoLength(x->lengthOf()));
std::vector<Nd4jLong> yShapeInfo(shape::shapeInfoLength(y->lengthOf()));
auto newshape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shapeLength, ArrayOptions::dataType(theFirst));
shapeList->push_back(newshape);
return shapeList;
}
// fill rank and data type
xShapeInfo[0] = x->lengthOf();
yShapeInfo[0] = y->lengthOf();
ArrayOptions::setDataType(xShapeInfo.data(), nd4j::DataType::INT64); // fill with some data type, it doesn't matter what type exactly to choose
ArrayOptions::setDataType(yShapeInfo.data(), nd4j::DataType::INT64);
}
for (Nd4jLong i = 0; i < x->lengthOf(); ++i)
xShapeInfo[i + 1] = x->e<Nd4jLong>(i);
for (Nd4jLong i = 0; i < y->lengthOf(); ++i)
yShapeInfo[i + 1] = y->e<Nd4jLong>(i);
Nd4jLong* poinerOnOutShapeInfo = nullptr;
const bool isBroadcastPossible = ShapeUtils::evalBroadcastShapeInfo(xShapeInfo.data(), yShapeInfo.data(), true, poinerOnOutShapeInfo, block.launchContext()->getWorkspace());
REQUIRE_TRUE(isBroadcastPossible, 0, "BROADCAST_DYNAMIC_SHAPE OP: the shapes of two input arrays %s and %s are not suitable for broadcast operation !", ShapeUtils::shapeAsString(xShapeInfo.data()).c_str(), ShapeUtils::shapeAsString(yShapeInfo.data()).c_str());
for (Nd4jLong i = 0; i < z->lengthOf(); ++i)
z->p<Nd4jLong>(i, poinerOnOutShapeInfo[i + 1]);
return Status::OK();
}
DECLARE_TYPES(broadcast_dynamic_shape) {
getOpDescriptor()
->setAllowedOutputTypes({ALL_INTS})
->setAllowedInputTypes({ALL_INTS});
}
//////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(broadcast_dynamic_shape) {
const int xRank = INPUT_VARIABLE(0)->lengthOf();
const int yRank = INPUT_VARIABLE(1)->lengthOf();
const int maxRank = xRank > yRank ? xRank : yRank;
auto outputShapeInfo = ConstantShapeHelper::getInstance()->vectorShapeInfo(maxRank, ArrayOptions::dataType(inputShape->at(0)));
return SHAPELIST(outputShapeInfo);
}
}
}
#endif

View File

@ -1,33 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author sgazeos@gmail.com
//
#ifndef __BDS_H_HELPERS__
#define __BDS_H_HELPERS__
#include <op_boilerplate.h>
#include <NDArray.h>
namespace nd4j {
namespace ops {
namespace helpers {
Nd4jStatus bdsFunctor(nd4j::LaunchContext * context, NDArray* x_shape, NDArray* y_shape, NDArray* output);
}
}
}
#endif

View File

@ -1,72 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author GS <sgazeos@gmail.com>
//
#include <ops/declarable/helpers/bds.h>
#include <Status.h>
namespace nd4j {
namespace ops {
namespace helpers {
Nd4jStatus bdsFunctor(nd4j::LaunchContext * context, NDArray* x_shape, NDArray* y_shape, NDArray* output) {
if (x_shape->lengthOf() == 1 || y_shape->lengthOf() == 1) {// except case
// lenght are equals
if (x_shape->lengthOf() == y_shape->lengthOf()) {
auto greater = (x_shape->e<Nd4jLong>(0) < y_shape->e<Nd4jLong>(0) ? y_shape : x_shape);
output->assign(greater);
}
else {
auto lesser = (x_shape->lengthOf() == 1 ? x_shape : y_shape);
auto greater = (x_shape->lengthOf() == 1 ? y_shape : x_shape);
output->assign(greater);
auto lastG = greater->lengthOf() - 1;
auto lastL = lesser->lengthOf() - 1;
if (greater->e<Nd4jLong>(lastG) < lesser->e<Nd4jLong>(lastL))
output->p(lastG, lesser->e(lastL));
}
}
else {
//int e = 0, x = 0, y = 0;
Nd4jLong xLen = x_shape->lengthOf();
Nd4jLong yLen = y_shape->lengthOf();
Nd4jLong zLen = output->lengthOf();
Nd4jLong borderLen = nd4j::math::nd4j_min(xLen, yLen);
for (Nd4jLong e = 0; e < zLen; e++) {
Nd4jLong val;
if (e < borderLen) {
val = nd4j::math::nd4j_max(x_shape->e<Nd4jLong>(e), y_shape->e<Nd4jLong>(e));
} else if (e < xLen) {
val = nd4j::math::nd4j_max(x_shape->e<Nd4jLong>(e), y_shape->e<Nd4jLong>(yLen - 1));
} else {
val = nd4j::math::nd4j_max(x_shape->e<Nd4jLong>(xLen - 1), y_shape->e<Nd4jLong>(e));
}
output->p(e, val);
}
}
return Status::OK();
}
}
}
}

View File

@ -1,113 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author GS <sgazeos@gmail.com>
//
#include <ops/declarable/helpers/bds.h>
#include <Status.h>
namespace nd4j {
namespace ops {
namespace helpers {
template <typename T>
static __global__ void bdsLoopKernel(void const* inputX, Nd4jLong const* inputXshape, void const* inputY, Nd4jLong const* inputYshape, void* output, Nd4jLong* outputShape) {
__shared__ T const* x;
__shared__ T const* y;
__shared__ T* z;
__shared__ bool speedWay;
//__shared__ int indexX, indexY;
__shared__ Nd4jLong xLen, yLen, outputLen;
if (threadIdx.x == 0) {
x = reinterpret_cast<T const*>(inputX);
y = reinterpret_cast<T const*>(inputY);
z = reinterpret_cast<T*>(output);
xLen = shape::length(inputXshape);
yLen = shape::length(inputYshape);
outputLen = shape::length(outputShape);
speedWay = true;
speedWay = speedWay && (shape::elementWiseStride(inputXshape) == 1);
speedWay = speedWay && (shape::elementWiseStride(inputYshape) == 1);
speedWay = speedWay && (shape::elementWiseStride(outputShape) == 1);
}
__syncthreads();
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
auto step = blockDim.x * gridDim.x;
for (int e = tid; e < outputLen; e += step) {
T val;
if (speedWay) {
if (e < nd4j::math::nd4j_min(yLen, xLen)) {
val = nd4j::math::nd4j_max(x[e], y[e]);
} else if (e < xLen) {
val = nd4j::math::nd4j_max(x[e], y[yLen - 1]);
} else {
val = nd4j::math::nd4j_max(x[xLen - 1], y[e]);
}
z[e] = val;
}
else {
auto xIndex = e < xLen?shape::getIndexOffset(e, inputXshape, xLen):shape::getIndexOffset(xLen, inputXshape, xLen);
auto yIndex = e < yLen?shape::getIndexOffset(e, inputYshape, yLen):shape::getIndexOffset(yLen - 1, inputYshape, yLen);
auto zIndex = shape::getIndexOffset(e, outputShape, outputLen);
z[zIndex] = nd4j::math::nd4j_max(x[xIndex], y[yIndex]);
}
}
}
template <typename T>
static void bdsLoopH(cudaStream_t* stream, void const* inputX, Nd4jLong const* inputXshape, void const* inputY, Nd4jLong const* inputYshape, void* output, Nd4jLong* outputShape) {
bdsLoopKernel<T><<<1, 256, 512, *stream>>>(inputX, inputXshape, inputY, inputYshape, output, outputShape);
}
Nd4jStatus bdsFunctor(nd4j::LaunchContext * context, NDArray* x_shape, NDArray* y_shape, NDArray* output) {
//int e = 0, x = 0, y = 0;
NDArray::prepareSpecialUse({output}, {x_shape, y_shape});
if (x_shape->lengthOf() == 1 || y_shape->lengthOf() == 1) {// except case
x_shape->syncToHost(); y_shape->syncToHost();
if (x_shape->lengthOf() == y_shape->lengthOf()) {
auto greater = (x_shape->e<Nd4jLong>(0) < y_shape->e<Nd4jLong>(0) ? y_shape : x_shape);
output->assign(greater);
}
else {
auto lesser = (x_shape->lengthOf() == 1 ? x_shape : y_shape);
auto greater = (x_shape->lengthOf() == 1 ? y_shape : x_shape);
output->assign(greater);
auto lastG = greater->lengthOf() - 1;
auto lastL = lesser->lengthOf() - 1;
if (greater->e<Nd4jLong>(lastG) < lesser->e<Nd4jLong>(lastL))
output->p(lastG, lesser->e(lastL));
output->syncToDevice();
}
}
else {
//bdsLoopH(context->getCudaStream(), x->getSpecialBuffer(), x->getSpecialShapeInfo(), y->getSpecialBuffer(), y->getSpecialShape(), output->specialBuffer(), output->specialShapeInfo())
BUILD_SINGLE_SELECTOR(output->dataType(), bdsLoopH, (context->getCudaStream(), x_shape->getSpecialBuffer(), x_shape->getSpecialShapeInfo(), y_shape->getSpecialBuffer(), y_shape->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()), NUMERIC_TYPES);
}
NDArray::registerSpecialUse({output}, {x_shape, y_shape});
return Status::OK();
return Status::OK();
}
}
}
}

View File

@ -1086,13 +1086,11 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_1) {
auto y = NDArrayFactory::create<int>({ 2, 1, 2});
// ------------------------------------
auto exp = NDArrayFactory::create<int>({2, 2, 2});
nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32);
auto res = op.execute({&x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, res->status());
ASSERT_TRUE(exp.equalsTo(res->at(0)));
@ -1107,7 +1105,6 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_2) {
auto y = NDArrayFactory::create<Nd4jLong>({2, 1, 2});
// ------------------------------------
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 2});
nd4j::ops::broadcast_dynamic_shape op;
@ -1122,17 +1119,15 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_2) {
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_3) {
auto x = NDArrayFactory::create<Nd4jLong>( {2, 2, 2} );
auto x = NDArrayFactory::create<int>( {2, 2, 2} );
auto y = NDArrayFactory::create<Nd4jLong>({ 2, 1});
auto y = NDArrayFactory::create<int>({2, 1});
// ------------------------------------
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 2});
auto exp = NDArrayFactory::create<int>({2, 2, 2});
nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
auto res = op.execute({&x, &y}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, res->status());
ASSERT_TRUE(exp.equalsTo(res->at(0)));
@ -1145,9 +1140,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_4) {
auto x = NDArrayFactory::create<Nd4jLong>( {2, 1} );
auto y = NDArrayFactory::create<Nd4jLong>('c', {1}, { 4,});
// ------------------------------------
auto y = NDArrayFactory::create<Nd4jLong>('c', {1}, {4});
auto exp = NDArrayFactory::create<Nd4jLong>({2, 4});
@ -1161,49 +1154,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_4) {
delete res;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_5) {
auto x = NDArrayFactory::create<Nd4jLong>({2, 2, 2});
auto y = NDArrayFactory::create<Nd4jLong>({2, 2});
// ------------------------------------
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 2});
nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
ASSERT_EQ(ND4J_STATUS_OK, res->status());
// res->at(0)->printIndexedBuffer("Output");
// exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_5) {
auto x = NDArrayFactory::create<Nd4jLong>({2, 1, 2});
auto y = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
// ------------------------------------
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
ASSERT_EQ(ND4J_STATUS_OK, res->status());
// res->at(0)->printIndexedBuffer("Output SGO 5");
// exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_6) {
@ -1211,16 +1162,12 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_6) {
auto y = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
// ------------------------------------
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
ASSERT_EQ(ND4J_STATUS_OK, res->status());
// res->at(0)->printIndexedBuffer("Output SGO 6");
// exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
@ -1233,16 +1180,12 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_7) {
auto y = NDArrayFactory::create<Nd4jLong>({2, 4, 1});
// ------------------------------------
auto exp = NDArrayFactory::create<Nd4jLong>({2, 4, 3});
nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
ASSERT_EQ(ND4J_STATUS_OK, res->status());
// res->at(0)->printIndexedBuffer("Output SGO 7");
// exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
@ -1255,19 +1198,15 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_8) {
auto y = NDArrayFactory::create<int>('c', {1}, {4});
// ------------------------------------
auto z = NDArrayFactory::create<int>('c', {1});
auto exp = NDArrayFactory::create<int>('c', {1}, {4});
nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32);
auto status = op.execute({&x, &y}, {&z}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, res->status());
// res->at(0)->printIndexedBuffer("Output SGO 8");
// exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(exp.equalsTo(z));
}
/////////////////////////////////////////////////////////////////////////////////
@ -1277,19 +1216,16 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_9) {
auto y = NDArrayFactory::create<int>('c', {1}, {1});
// ------------------------------------
auto z = NDArrayFactory::create<Nd4jLong>('c', {2});
auto exp = NDArrayFactory::create<int>('c', {2}, {2,2});
auto exp = NDArrayFactory::create<Nd4jLong>('c', {2}, {2,2});
nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32);
auto status = op.execute({&x, &y}, {&z}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, res->status());
res->at(0)->printIndexedBuffer("Output SGO 9");
exp.printIndexedBuffer("Expect9");
ASSERT_TRUE(exp.equalsTo(res->at(0)));
ASSERT_EQ(ND4J_STATUS_OK, status);
// ASSERT_TRUE(exp.equalsTo(z));
delete res;
}
////////////////////////////////////////////////////////////////////////////////