- rewrite broadcast_dynamic_shape and delete corresponding helpers (#194)
Signed-off-by: Yurii <yurii@skymind.io>master
parent
0463ee4eba
commit
5395d4fbe5
|
@ -15,51 +15,76 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
#if NOT_EXCLUDED(OP_broadcast_dynamic_shape)
|
#if NOT_EXCLUDED(OP_broadcast_dynamic_shape)
|
||||||
|
|
||||||
//#include <ops/declarable/headers/parity_ops.h>
|
|
||||||
#include <ops/declarable/CustomOperations.h>
|
#include <ops/declarable/CustomOperations.h>
|
||||||
#include <ops/declarable/helpers/bds.h>
|
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
CUSTOM_OP_IMPL(broadcast_dynamic_shape, 2, 1, false, 0, 0) {
|
||||||
|
|
||||||
|
auto x = INPUT_VARIABLE(0);
|
||||||
|
auto y = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(x->rankOf() == 1, 0, "BROADCAST_DYNAMIC_SHAPE OP: the first input array must have rank = 1, but got %i instead!", x->rankOf());
|
||||||
|
REQUIRE_TRUE(y->rankOf() == 1, 0, "BROADCAST_DYNAMIC_SHAPE OP: the second input array must have rank = 1, but got %i instead!", y->rankOf());
|
||||||
|
REQUIRE_TRUE(x->dataType() == y->dataType(), 0, "BROADCAST_DYNAMIC_SHAPE OP: both input arrays must have the same integer type !");
|
||||||
|
|
||||||
|
// contract shapeInfos, neglect and don't fill strides, ews, order
|
||||||
|
// shapes are of interest only
|
||||||
|
std::vector<Nd4jLong> xShapeInfo(shape::shapeInfoLength(x->lengthOf()));
|
||||||
|
std::vector<Nd4jLong> yShapeInfo(shape::shapeInfoLength(y->lengthOf()));
|
||||||
|
|
||||||
|
// fill rank and data type
|
||||||
|
xShapeInfo[0] = x->lengthOf();
|
||||||
|
yShapeInfo[0] = y->lengthOf();
|
||||||
|
ArrayOptions::setDataType(xShapeInfo.data(), nd4j::DataType::INT64); // fill with some data type, it doesn't matter what type exactly to choose
|
||||||
|
ArrayOptions::setDataType(yShapeInfo.data(), nd4j::DataType::INT64);
|
||||||
|
|
||||||
|
for (Nd4jLong i = 0; i < x->lengthOf(); ++i)
|
||||||
|
xShapeInfo[i + 1] = x->e<Nd4jLong>(i);
|
||||||
|
|
||||||
|
for (Nd4jLong i = 0; i < y->lengthOf(); ++i)
|
||||||
|
yShapeInfo[i + 1] = y->e<Nd4jLong>(i);
|
||||||
|
|
||||||
|
Nd4jLong* poinerOnOutShapeInfo = nullptr;
|
||||||
|
|
||||||
|
const bool isBroadcastPossible = ShapeUtils::evalBroadcastShapeInfo(xShapeInfo.data(), yShapeInfo.data(), true, poinerOnOutShapeInfo, block.launchContext()->getWorkspace());
|
||||||
|
|
||||||
|
REQUIRE_TRUE(isBroadcastPossible, 0, "BROADCAST_DYNAMIC_SHAPE OP: the shapes of two input arrays %s and %s are not suitable for broadcast operation !", ShapeUtils::shapeAsString(xShapeInfo.data()).c_str(), ShapeUtils::shapeAsString(yShapeInfo.data()).c_str());
|
||||||
|
|
||||||
|
for (Nd4jLong i = 0; i < z->lengthOf(); ++i)
|
||||||
|
z->p<Nd4jLong>(i, poinerOnOutShapeInfo[i + 1]);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(broadcast_dynamic_shape) {
|
DECLARE_TYPES(broadcast_dynamic_shape) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedOutputTypes({ALL_INTS})
|
->setAllowedOutputTypes({ALL_INTS})
|
||||||
->setAllowedInputTypes({ALL_INTS})
|
->setAllowedInputTypes({ALL_INTS});
|
||||||
->setSameMode(true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(broadcast_dynamic_shape, 2, 1, false, 0, 0) {
|
|
||||||
auto x_shape = INPUT_VARIABLE(0);
|
|
||||||
auto y_shape = INPUT_VARIABLE(1);
|
|
||||||
|
|
||||||
REQUIRE_TRUE(shape::isVector(x_shape->shapeInfo(), 1), 0, "broadcast_dynamic_shape: The first argument should be a vector");
|
|
||||||
REQUIRE_TRUE(shape::isVector(y_shape->shapeInfo(), 1), 0, "broadcast_dynamic_shape: The second argument should be a vector");
|
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
|
||||||
|
|
||||||
return helpers::bdsFunctor(block.launchContext(), x_shape, y_shape, output);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
DECLARE_SHAPE_FN(broadcast_dynamic_shape) {
|
DECLARE_SHAPE_FN(broadcast_dynamic_shape) {
|
||||||
auto shapeList = SHAPELIST();
|
|
||||||
|
|
||||||
auto theFirst = inputShape->at(0);
|
const int xRank = INPUT_VARIABLE(0)->lengthOf();
|
||||||
auto theSecond = inputShape->at(1);
|
const int yRank = INPUT_VARIABLE(1)->lengthOf();
|
||||||
|
|
||||||
auto theFirstLen = shape::sizeAt(theFirst, -1);
|
const int maxRank = xRank > yRank ? xRank : yRank;
|
||||||
auto theSecondLen = shape::sizeAt(theSecond, -1);
|
|
||||||
|
|
||||||
auto shapeLength = nd4j::math::nd4j_max(theFirstLen, theSecondLen);
|
auto outputShapeInfo = ConstantShapeHelper::getInstance()->vectorShapeInfo(maxRank, ArrayOptions::dataType(inputShape->at(0)));
|
||||||
|
|
||||||
auto newshape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shapeLength, ArrayOptions::dataType(theFirst));
|
return SHAPELIST(outputShapeInfo);
|
||||||
shapeList->push_back(newshape);
|
|
||||||
return shapeList;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,33 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
//
|
|
||||||
// @author sgazeos@gmail.com
|
|
||||||
//
|
|
||||||
#ifndef __BDS_H_HELPERS__
|
|
||||||
#define __BDS_H_HELPERS__
|
|
||||||
#include <op_boilerplate.h>
|
|
||||||
#include <NDArray.h>
|
|
||||||
|
|
||||||
namespace nd4j {
|
|
||||||
namespace ops {
|
|
||||||
namespace helpers {
|
|
||||||
|
|
||||||
Nd4jStatus bdsFunctor(nd4j::LaunchContext * context, NDArray* x_shape, NDArray* y_shape, NDArray* output);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
|
|
@ -1,72 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
//
|
|
||||||
// @author GS <sgazeos@gmail.com>
|
|
||||||
//
|
|
||||||
|
|
||||||
#include <ops/declarable/helpers/bds.h>
|
|
||||||
#include <Status.h>
|
|
||||||
|
|
||||||
|
|
||||||
namespace nd4j {
|
|
||||||
namespace ops {
|
|
||||||
namespace helpers {
|
|
||||||
|
|
||||||
Nd4jStatus bdsFunctor(nd4j::LaunchContext * context, NDArray* x_shape, NDArray* y_shape, NDArray* output) {
|
|
||||||
|
|
||||||
|
|
||||||
if (x_shape->lengthOf() == 1 || y_shape->lengthOf() == 1) {// except case
|
|
||||||
// lenght are equals
|
|
||||||
if (x_shape->lengthOf() == y_shape->lengthOf()) {
|
|
||||||
auto greater = (x_shape->e<Nd4jLong>(0) < y_shape->e<Nd4jLong>(0) ? y_shape : x_shape);
|
|
||||||
output->assign(greater);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
auto lesser = (x_shape->lengthOf() == 1 ? x_shape : y_shape);
|
|
||||||
auto greater = (x_shape->lengthOf() == 1 ? y_shape : x_shape);
|
|
||||||
output->assign(greater);
|
|
||||||
auto lastG = greater->lengthOf() - 1;
|
|
||||||
auto lastL = lesser->lengthOf() - 1;
|
|
||||||
if (greater->e<Nd4jLong>(lastG) < lesser->e<Nd4jLong>(lastL))
|
|
||||||
output->p(lastG, lesser->e(lastL));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
//int e = 0, x = 0, y = 0;
|
|
||||||
Nd4jLong xLen = x_shape->lengthOf();
|
|
||||||
Nd4jLong yLen = y_shape->lengthOf();
|
|
||||||
Nd4jLong zLen = output->lengthOf();
|
|
||||||
Nd4jLong borderLen = nd4j::math::nd4j_min(xLen, yLen);
|
|
||||||
for (Nd4jLong e = 0; e < zLen; e++) {
|
|
||||||
Nd4jLong val;
|
|
||||||
if (e < borderLen) {
|
|
||||||
val = nd4j::math::nd4j_max(x_shape->e<Nd4jLong>(e), y_shape->e<Nd4jLong>(e));
|
|
||||||
} else if (e < xLen) {
|
|
||||||
val = nd4j::math::nd4j_max(x_shape->e<Nd4jLong>(e), y_shape->e<Nd4jLong>(yLen - 1));
|
|
||||||
} else {
|
|
||||||
val = nd4j::math::nd4j_max(x_shape->e<Nd4jLong>(xLen - 1), y_shape->e<Nd4jLong>(e));
|
|
||||||
}
|
|
||||||
|
|
||||||
output->p(e, val);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,113 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
//
|
|
||||||
// @author GS <sgazeos@gmail.com>
|
|
||||||
//
|
|
||||||
|
|
||||||
#include <ops/declarable/helpers/bds.h>
|
|
||||||
#include <Status.h>
|
|
||||||
|
|
||||||
|
|
||||||
namespace nd4j {
|
|
||||||
namespace ops {
|
|
||||||
namespace helpers {
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static __global__ void bdsLoopKernel(void const* inputX, Nd4jLong const* inputXshape, void const* inputY, Nd4jLong const* inputYshape, void* output, Nd4jLong* outputShape) {
|
|
||||||
__shared__ T const* x;
|
|
||||||
__shared__ T const* y;
|
|
||||||
__shared__ T* z;
|
|
||||||
__shared__ bool speedWay;
|
|
||||||
//__shared__ int indexX, indexY;
|
|
||||||
__shared__ Nd4jLong xLen, yLen, outputLen;
|
|
||||||
if (threadIdx.x == 0) {
|
|
||||||
x = reinterpret_cast<T const*>(inputX);
|
|
||||||
y = reinterpret_cast<T const*>(inputY);
|
|
||||||
z = reinterpret_cast<T*>(output);
|
|
||||||
xLen = shape::length(inputXshape);
|
|
||||||
yLen = shape::length(inputYshape);
|
|
||||||
outputLen = shape::length(outputShape);
|
|
||||||
speedWay = true;
|
|
||||||
speedWay = speedWay && (shape::elementWiseStride(inputXshape) == 1);
|
|
||||||
speedWay = speedWay && (shape::elementWiseStride(inputYshape) == 1);
|
|
||||||
speedWay = speedWay && (shape::elementWiseStride(outputShape) == 1);
|
|
||||||
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
|
||||||
auto step = blockDim.x * gridDim.x;
|
|
||||||
for (int e = tid; e < outputLen; e += step) {
|
|
||||||
T val;
|
|
||||||
if (speedWay) {
|
|
||||||
if (e < nd4j::math::nd4j_min(yLen, xLen)) {
|
|
||||||
val = nd4j::math::nd4j_max(x[e], y[e]);
|
|
||||||
} else if (e < xLen) {
|
|
||||||
val = nd4j::math::nd4j_max(x[e], y[yLen - 1]);
|
|
||||||
} else {
|
|
||||||
val = nd4j::math::nd4j_max(x[xLen - 1], y[e]);
|
|
||||||
}
|
|
||||||
z[e] = val;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
auto xIndex = e < xLen?shape::getIndexOffset(e, inputXshape, xLen):shape::getIndexOffset(xLen, inputXshape, xLen);
|
|
||||||
auto yIndex = e < yLen?shape::getIndexOffset(e, inputYshape, yLen):shape::getIndexOffset(yLen - 1, inputYshape, yLen);
|
|
||||||
auto zIndex = shape::getIndexOffset(e, outputShape, outputLen);
|
|
||||||
z[zIndex] = nd4j::math::nd4j_max(x[xIndex], y[yIndex]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static void bdsLoopH(cudaStream_t* stream, void const* inputX, Nd4jLong const* inputXshape, void const* inputY, Nd4jLong const* inputYshape, void* output, Nd4jLong* outputShape) {
|
|
||||||
bdsLoopKernel<T><<<1, 256, 512, *stream>>>(inputX, inputXshape, inputY, inputYshape, output, outputShape);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
Nd4jStatus bdsFunctor(nd4j::LaunchContext * context, NDArray* x_shape, NDArray* y_shape, NDArray* output) {
|
|
||||||
//int e = 0, x = 0, y = 0;
|
|
||||||
NDArray::prepareSpecialUse({output}, {x_shape, y_shape});
|
|
||||||
if (x_shape->lengthOf() == 1 || y_shape->lengthOf() == 1) {// except case
|
|
||||||
x_shape->syncToHost(); y_shape->syncToHost();
|
|
||||||
if (x_shape->lengthOf() == y_shape->lengthOf()) {
|
|
||||||
auto greater = (x_shape->e<Nd4jLong>(0) < y_shape->e<Nd4jLong>(0) ? y_shape : x_shape);
|
|
||||||
output->assign(greater);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
auto lesser = (x_shape->lengthOf() == 1 ? x_shape : y_shape);
|
|
||||||
auto greater = (x_shape->lengthOf() == 1 ? y_shape : x_shape);
|
|
||||||
output->assign(greater);
|
|
||||||
auto lastG = greater->lengthOf() - 1;
|
|
||||||
auto lastL = lesser->lengthOf() - 1;
|
|
||||||
if (greater->e<Nd4jLong>(lastG) < lesser->e<Nd4jLong>(lastL))
|
|
||||||
output->p(lastG, lesser->e(lastL));
|
|
||||||
output->syncToDevice();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
//bdsLoopH(context->getCudaStream(), x->getSpecialBuffer(), x->getSpecialShapeInfo(), y->getSpecialBuffer(), y->getSpecialShape(), output->specialBuffer(), output->specialShapeInfo())
|
|
||||||
BUILD_SINGLE_SELECTOR(output->dataType(), bdsLoopH, (context->getCudaStream(), x_shape->getSpecialBuffer(), x_shape->getSpecialShapeInfo(), y_shape->getSpecialBuffer(), y_shape->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()), NUMERIC_TYPES);
|
|
||||||
}
|
|
||||||
NDArray::registerSpecialUse({output}, {x_shape, y_shape});
|
|
||||||
return Status::OK();
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1086,13 +1086,11 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_1) {
|
||||||
|
|
||||||
auto y = NDArrayFactory::create<int>({ 2, 1, 2});
|
auto y = NDArrayFactory::create<int>({ 2, 1, 2});
|
||||||
|
|
||||||
// ------------------------------------
|
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<int>({2, 2, 2});
|
auto exp = NDArrayFactory::create<int>({2, 2, 2});
|
||||||
|
|
||||||
nd4j::ops::broadcast_dynamic_shape op;
|
nd4j::ops::broadcast_dynamic_shape op;
|
||||||
|
|
||||||
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32);
|
auto res = op.execute({&x, &y}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
|
@ -1107,7 +1105,6 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_2) {
|
||||||
|
|
||||||
auto y = NDArrayFactory::create<Nd4jLong>({2, 1, 2});
|
auto y = NDArrayFactory::create<Nd4jLong>({2, 1, 2});
|
||||||
|
|
||||||
// ------------------------------------
|
|
||||||
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 2});
|
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 2});
|
||||||
|
|
||||||
nd4j::ops::broadcast_dynamic_shape op;
|
nd4j::ops::broadcast_dynamic_shape op;
|
||||||
|
@ -1122,17 +1119,15 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_2) {
|
||||||
/////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_3) {
|
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_3) {
|
||||||
|
|
||||||
auto x = NDArrayFactory::create<Nd4jLong>( {2, 2, 2} );
|
auto x = NDArrayFactory::create<int>( {2, 2, 2} );
|
||||||
|
|
||||||
auto y = NDArrayFactory::create<Nd4jLong>({ 2, 1});
|
auto y = NDArrayFactory::create<int>({2, 1});
|
||||||
|
|
||||||
// ------------------------------------
|
auto exp = NDArrayFactory::create<int>({2, 2, 2});
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 2});
|
|
||||||
|
|
||||||
nd4j::ops::broadcast_dynamic_shape op;
|
nd4j::ops::broadcast_dynamic_shape op;
|
||||||
|
|
||||||
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
|
auto res = op.execute({&x, &y}, {}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
|
@ -1145,9 +1140,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_4) {
|
||||||
|
|
||||||
auto x = NDArrayFactory::create<Nd4jLong>( {2, 1} );
|
auto x = NDArrayFactory::create<Nd4jLong>( {2, 1} );
|
||||||
|
|
||||||
auto y = NDArrayFactory::create<Nd4jLong>('c', {1}, { 4,});
|
auto y = NDArrayFactory::create<Nd4jLong>('c', {1}, {4});
|
||||||
|
|
||||||
// ------------------------------------
|
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<Nd4jLong>({2, 4});
|
auto exp = NDArrayFactory::create<Nd4jLong>({2, 4});
|
||||||
|
|
||||||
|
@ -1161,49 +1154,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_4) {
|
||||||
|
|
||||||
delete res;
|
delete res;
|
||||||
}
|
}
|
||||||
/////////////////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_5) {
|
|
||||||
|
|
||||||
auto x = NDArrayFactory::create<Nd4jLong>({2, 2, 2});
|
|
||||||
|
|
||||||
auto y = NDArrayFactory::create<Nd4jLong>({2, 2});
|
|
||||||
|
|
||||||
// ------------------------------------
|
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 2});
|
|
||||||
|
|
||||||
nd4j::ops::broadcast_dynamic_shape op;
|
|
||||||
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
|
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
|
||||||
// res->at(0)->printIndexedBuffer("Output");
|
|
||||||
// exp.printIndexedBuffer("Expect");
|
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
|
||||||
|
|
||||||
delete res;
|
|
||||||
}
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_5) {
|
|
||||||
|
|
||||||
auto x = NDArrayFactory::create<Nd4jLong>({2, 1, 2});
|
|
||||||
|
|
||||||
auto y = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
|
|
||||||
|
|
||||||
// ------------------------------------
|
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
|
|
||||||
|
|
||||||
nd4j::ops::broadcast_dynamic_shape op;
|
|
||||||
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
|
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
|
||||||
// res->at(0)->printIndexedBuffer("Output SGO 5");
|
|
||||||
// exp.printIndexedBuffer("Expect");
|
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
|
||||||
|
|
||||||
delete res;
|
|
||||||
}
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_6) {
|
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_6) {
|
||||||
|
|
||||||
|
@ -1211,16 +1162,12 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_6) {
|
||||||
|
|
||||||
auto y = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
|
auto y = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
|
||||||
|
|
||||||
// ------------------------------------
|
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
|
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
|
||||||
|
|
||||||
nd4j::ops::broadcast_dynamic_shape op;
|
nd4j::ops::broadcast_dynamic_shape op;
|
||||||
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
|
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||||
// res->at(0)->printIndexedBuffer("Output SGO 6");
|
|
||||||
// exp.printIndexedBuffer("Expect");
|
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
|
|
||||||
delete res;
|
delete res;
|
||||||
|
@ -1233,16 +1180,12 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_7) {
|
||||||
|
|
||||||
auto y = NDArrayFactory::create<Nd4jLong>({2, 4, 1});
|
auto y = NDArrayFactory::create<Nd4jLong>({2, 4, 1});
|
||||||
|
|
||||||
// ------------------------------------
|
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<Nd4jLong>({2, 4, 3});
|
auto exp = NDArrayFactory::create<Nd4jLong>({2, 4, 3});
|
||||||
|
|
||||||
nd4j::ops::broadcast_dynamic_shape op;
|
nd4j::ops::broadcast_dynamic_shape op;
|
||||||
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
|
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||||
// res->at(0)->printIndexedBuffer("Output SGO 7");
|
|
||||||
// exp.printIndexedBuffer("Expect");
|
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
|
|
||||||
delete res;
|
delete res;
|
||||||
|
@ -1255,19 +1198,15 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_8) {
|
||||||
|
|
||||||
auto y = NDArrayFactory::create<int>('c', {1}, {4});
|
auto y = NDArrayFactory::create<int>('c', {1}, {4});
|
||||||
|
|
||||||
// ------------------------------------
|
auto z = NDArrayFactory::create<int>('c', {1});
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<int>('c', {1}, {4});
|
auto exp = NDArrayFactory::create<int>('c', {1}, {4});
|
||||||
|
|
||||||
nd4j::ops::broadcast_dynamic_shape op;
|
nd4j::ops::broadcast_dynamic_shape op;
|
||||||
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32);
|
auto status = op.execute({&x, &y}, {&z}, {}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
// res->at(0)->printIndexedBuffer("Output SGO 8");
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
// exp.printIndexedBuffer("Expect");
|
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
|
||||||
|
|
||||||
delete res;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1277,19 +1216,16 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_9) {
|
||||||
|
|
||||||
auto y = NDArrayFactory::create<int>('c', {1}, {1});
|
auto y = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
|
||||||
// ------------------------------------
|
auto z = NDArrayFactory::create<Nd4jLong>('c', {2});
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<int>('c', {2}, {2,2});
|
auto exp = NDArrayFactory::create<Nd4jLong>('c', {2}, {2,2});
|
||||||
|
|
||||||
nd4j::ops::broadcast_dynamic_shape op;
|
nd4j::ops::broadcast_dynamic_shape op;
|
||||||
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32);
|
auto status = op.execute({&x, &y}, {&z}, {}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
res->at(0)->printIndexedBuffer("Output SGO 9");
|
// ASSERT_TRUE(exp.equalsTo(z));
|
||||||
exp.printIndexedBuffer("Expect9");
|
|
||||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
|
||||||
|
|
||||||
delete res;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
Loading…
Reference in New Issue