- rewrite broadcast_dynamic_shape and delete corresponding helpers (#194)
Signed-off-by: Yurii <yurii@skymind.io>master
parent
0463ee4eba
commit
5395d4fbe5
|
@ -15,54 +15,79 @@
|
|||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
#if NOT_EXCLUDED(OP_broadcast_dynamic_shape)
|
||||
|
||||
//#include <ops/declarable/headers/parity_ops.h>
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <ops/declarable/helpers/bds.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
DECLARE_TYPES(broadcast_dynamic_shape) {
|
||||
getOpDescriptor()
|
||||
->setAllowedOutputTypes({ALL_INTS})
|
||||
->setAllowedInputTypes({ALL_INTS})
|
||||
->setSameMode(true);
|
||||
}
|
||||
namespace ops {
|
||||
|
||||
CUSTOM_OP_IMPL(broadcast_dynamic_shape, 2, 1, false, 0, 0) {
|
||||
auto x_shape = INPUT_VARIABLE(0);
|
||||
auto y_shape = INPUT_VARIABLE(1);
|
||||
|
||||
REQUIRE_TRUE(shape::isVector(x_shape->shapeInfo(), 1), 0, "broadcast_dynamic_shape: The first argument should be a vector");
|
||||
REQUIRE_TRUE(shape::isVector(y_shape->shapeInfo(), 1), 0, "broadcast_dynamic_shape: The second argument should be a vector");
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
CUSTOM_OP_IMPL(broadcast_dynamic_shape, 2, 1, false, 0, 0) {
|
||||
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
return helpers::bdsFunctor(block.launchContext(), x_shape, y_shape, output);
|
||||
}
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
|
||||
DECLARE_SHAPE_FN(broadcast_dynamic_shape) {
|
||||
auto shapeList = SHAPELIST();
|
||||
|
||||
auto theFirst = inputShape->at(0);
|
||||
auto theSecond = inputShape->at(1);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
||||
auto theFirstLen = shape::sizeAt(theFirst, -1);
|
||||
auto theSecondLen = shape::sizeAt(theSecond, -1);
|
||||
REQUIRE_TRUE(x->rankOf() == 1, 0, "BROADCAST_DYNAMIC_SHAPE OP: the first input array must have rank = 1, but got %i instead!", x->rankOf());
|
||||
REQUIRE_TRUE(y->rankOf() == 1, 0, "BROADCAST_DYNAMIC_SHAPE OP: the second input array must have rank = 1, but got %i instead!", y->rankOf());
|
||||
REQUIRE_TRUE(x->dataType() == y->dataType(), 0, "BROADCAST_DYNAMIC_SHAPE OP: both input arrays must have the same integer type !");
|
||||
|
||||
auto shapeLength = nd4j::math::nd4j_max(theFirstLen, theSecondLen);
|
||||
// contract shapeInfos, neglect and don't fill strides, ews, order
|
||||
// shapes are of interest only
|
||||
std::vector<Nd4jLong> xShapeInfo(shape::shapeInfoLength(x->lengthOf()));
|
||||
std::vector<Nd4jLong> yShapeInfo(shape::shapeInfoLength(y->lengthOf()));
|
||||
|
||||
auto newshape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shapeLength, ArrayOptions::dataType(theFirst));
|
||||
shapeList->push_back(newshape);
|
||||
return shapeList;
|
||||
}
|
||||
// fill rank and data type
|
||||
xShapeInfo[0] = x->lengthOf();
|
||||
yShapeInfo[0] = y->lengthOf();
|
||||
ArrayOptions::setDataType(xShapeInfo.data(), nd4j::DataType::INT64); // fill with some data type, it doesn't matter what type exactly to choose
|
||||
ArrayOptions::setDataType(yShapeInfo.data(), nd4j::DataType::INT64);
|
||||
|
||||
}
|
||||
for (Nd4jLong i = 0; i < x->lengthOf(); ++i)
|
||||
xShapeInfo[i + 1] = x->e<Nd4jLong>(i);
|
||||
|
||||
for (Nd4jLong i = 0; i < y->lengthOf(); ++i)
|
||||
yShapeInfo[i + 1] = y->e<Nd4jLong>(i);
|
||||
|
||||
Nd4jLong* poinerOnOutShapeInfo = nullptr;
|
||||
|
||||
const bool isBroadcastPossible = ShapeUtils::evalBroadcastShapeInfo(xShapeInfo.data(), yShapeInfo.data(), true, poinerOnOutShapeInfo, block.launchContext()->getWorkspace());
|
||||
|
||||
REQUIRE_TRUE(isBroadcastPossible, 0, "BROADCAST_DYNAMIC_SHAPE OP: the shapes of two input arrays %s and %s are not suitable for broadcast operation !", ShapeUtils::shapeAsString(xShapeInfo.data()).c_str(), ShapeUtils::shapeAsString(yShapeInfo.data()).c_str());
|
||||
|
||||
for (Nd4jLong i = 0; i < z->lengthOf(); ++i)
|
||||
z->p<Nd4jLong>(i, poinerOnOutShapeInfo[i + 1]);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_TYPES(broadcast_dynamic_shape) {
|
||||
getOpDescriptor()
|
||||
->setAllowedOutputTypes({ALL_INTS})
|
||||
->setAllowedInputTypes({ALL_INTS});
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
DECLARE_SHAPE_FN(broadcast_dynamic_shape) {
|
||||
|
||||
const int xRank = INPUT_VARIABLE(0)->lengthOf();
|
||||
const int yRank = INPUT_VARIABLE(1)->lengthOf();
|
||||
|
||||
const int maxRank = xRank > yRank ? xRank : yRank;
|
||||
|
||||
auto outputShapeInfo = ConstantShapeHelper::getInstance()->vectorShapeInfo(maxRank, ArrayOptions::dataType(inputShape->at(0)));
|
||||
|
||||
return SHAPELIST(outputShapeInfo);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -1,33 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author sgazeos@gmail.com
|
||||
//
|
||||
#ifndef __BDS_H_HELPERS__
|
||||
#define __BDS_H_HELPERS__
|
||||
#include <op_boilerplate.h>
|
||||
#include <NDArray.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
Nd4jStatus bdsFunctor(nd4j::LaunchContext * context, NDArray* x_shape, NDArray* y_shape, NDArray* output);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
|
@ -1,72 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author GS <sgazeos@gmail.com>
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/bds.h>
|
||||
#include <Status.h>
|
||||
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
Nd4jStatus bdsFunctor(nd4j::LaunchContext * context, NDArray* x_shape, NDArray* y_shape, NDArray* output) {
|
||||
|
||||
|
||||
if (x_shape->lengthOf() == 1 || y_shape->lengthOf() == 1) {// except case
|
||||
// lenght are equals
|
||||
if (x_shape->lengthOf() == y_shape->lengthOf()) {
|
||||
auto greater = (x_shape->e<Nd4jLong>(0) < y_shape->e<Nd4jLong>(0) ? y_shape : x_shape);
|
||||
output->assign(greater);
|
||||
}
|
||||
else {
|
||||
auto lesser = (x_shape->lengthOf() == 1 ? x_shape : y_shape);
|
||||
auto greater = (x_shape->lengthOf() == 1 ? y_shape : x_shape);
|
||||
output->assign(greater);
|
||||
auto lastG = greater->lengthOf() - 1;
|
||||
auto lastL = lesser->lengthOf() - 1;
|
||||
if (greater->e<Nd4jLong>(lastG) < lesser->e<Nd4jLong>(lastL))
|
||||
output->p(lastG, lesser->e(lastL));
|
||||
}
|
||||
}
|
||||
else {
|
||||
//int e = 0, x = 0, y = 0;
|
||||
Nd4jLong xLen = x_shape->lengthOf();
|
||||
Nd4jLong yLen = y_shape->lengthOf();
|
||||
Nd4jLong zLen = output->lengthOf();
|
||||
Nd4jLong borderLen = nd4j::math::nd4j_min(xLen, yLen);
|
||||
for (Nd4jLong e = 0; e < zLen; e++) {
|
||||
Nd4jLong val;
|
||||
if (e < borderLen) {
|
||||
val = nd4j::math::nd4j_max(x_shape->e<Nd4jLong>(e), y_shape->e<Nd4jLong>(e));
|
||||
} else if (e < xLen) {
|
||||
val = nd4j::math::nd4j_max(x_shape->e<Nd4jLong>(e), y_shape->e<Nd4jLong>(yLen - 1));
|
||||
} else {
|
||||
val = nd4j::math::nd4j_max(x_shape->e<Nd4jLong>(xLen - 1), y_shape->e<Nd4jLong>(e));
|
||||
}
|
||||
|
||||
output->p(e, val);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,113 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author GS <sgazeos@gmail.com>
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/bds.h>
|
||||
#include <Status.h>
|
||||
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
|
||||
template <typename T>
|
||||
static __global__ void bdsLoopKernel(void const* inputX, Nd4jLong const* inputXshape, void const* inputY, Nd4jLong const* inputYshape, void* output, Nd4jLong* outputShape) {
|
||||
__shared__ T const* x;
|
||||
__shared__ T const* y;
|
||||
__shared__ T* z;
|
||||
__shared__ bool speedWay;
|
||||
//__shared__ int indexX, indexY;
|
||||
__shared__ Nd4jLong xLen, yLen, outputLen;
|
||||
if (threadIdx.x == 0) {
|
||||
x = reinterpret_cast<T const*>(inputX);
|
||||
y = reinterpret_cast<T const*>(inputY);
|
||||
z = reinterpret_cast<T*>(output);
|
||||
xLen = shape::length(inputXshape);
|
||||
yLen = shape::length(inputYshape);
|
||||
outputLen = shape::length(outputShape);
|
||||
speedWay = true;
|
||||
speedWay = speedWay && (shape::elementWiseStride(inputXshape) == 1);
|
||||
speedWay = speedWay && (shape::elementWiseStride(inputYshape) == 1);
|
||||
speedWay = speedWay && (shape::elementWiseStride(outputShape) == 1);
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
auto step = blockDim.x * gridDim.x;
|
||||
for (int e = tid; e < outputLen; e += step) {
|
||||
T val;
|
||||
if (speedWay) {
|
||||
if (e < nd4j::math::nd4j_min(yLen, xLen)) {
|
||||
val = nd4j::math::nd4j_max(x[e], y[e]);
|
||||
} else if (e < xLen) {
|
||||
val = nd4j::math::nd4j_max(x[e], y[yLen - 1]);
|
||||
} else {
|
||||
val = nd4j::math::nd4j_max(x[xLen - 1], y[e]);
|
||||
}
|
||||
z[e] = val;
|
||||
}
|
||||
else {
|
||||
auto xIndex = e < xLen?shape::getIndexOffset(e, inputXshape, xLen):shape::getIndexOffset(xLen, inputXshape, xLen);
|
||||
auto yIndex = e < yLen?shape::getIndexOffset(e, inputYshape, yLen):shape::getIndexOffset(yLen - 1, inputYshape, yLen);
|
||||
auto zIndex = shape::getIndexOffset(e, outputShape, outputLen);
|
||||
z[zIndex] = nd4j::math::nd4j_max(x[xIndex], y[yIndex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void bdsLoopH(cudaStream_t* stream, void const* inputX, Nd4jLong const* inputXshape, void const* inputY, Nd4jLong const* inputYshape, void* output, Nd4jLong* outputShape) {
|
||||
bdsLoopKernel<T><<<1, 256, 512, *stream>>>(inputX, inputXshape, inputY, inputYshape, output, outputShape);
|
||||
|
||||
}
|
||||
|
||||
Nd4jStatus bdsFunctor(nd4j::LaunchContext * context, NDArray* x_shape, NDArray* y_shape, NDArray* output) {
|
||||
//int e = 0, x = 0, y = 0;
|
||||
NDArray::prepareSpecialUse({output}, {x_shape, y_shape});
|
||||
if (x_shape->lengthOf() == 1 || y_shape->lengthOf() == 1) {// except case
|
||||
x_shape->syncToHost(); y_shape->syncToHost();
|
||||
if (x_shape->lengthOf() == y_shape->lengthOf()) {
|
||||
auto greater = (x_shape->e<Nd4jLong>(0) < y_shape->e<Nd4jLong>(0) ? y_shape : x_shape);
|
||||
output->assign(greater);
|
||||
}
|
||||
else {
|
||||
auto lesser = (x_shape->lengthOf() == 1 ? x_shape : y_shape);
|
||||
auto greater = (x_shape->lengthOf() == 1 ? y_shape : x_shape);
|
||||
output->assign(greater);
|
||||
auto lastG = greater->lengthOf() - 1;
|
||||
auto lastL = lesser->lengthOf() - 1;
|
||||
if (greater->e<Nd4jLong>(lastG) < lesser->e<Nd4jLong>(lastL))
|
||||
output->p(lastG, lesser->e(lastL));
|
||||
output->syncToDevice();
|
||||
}
|
||||
}
|
||||
else {
|
||||
//bdsLoopH(context->getCudaStream(), x->getSpecialBuffer(), x->getSpecialShapeInfo(), y->getSpecialBuffer(), y->getSpecialShape(), output->specialBuffer(), output->specialShapeInfo())
|
||||
BUILD_SINGLE_SELECTOR(output->dataType(), bdsLoopH, (context->getCudaStream(), x_shape->getSpecialBuffer(), x_shape->getSpecialShapeInfo(), y_shape->getSpecialBuffer(), y_shape->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()), NUMERIC_TYPES);
|
||||
}
|
||||
NDArray::registerSpecialUse({output}, {x_shape, y_shape});
|
||||
return Status::OK();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1086,13 +1086,11 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_1) {
|
|||
|
||||
auto y = NDArrayFactory::create<int>({ 2, 1, 2});
|
||||
|
||||
// ------------------------------------
|
||||
|
||||
auto exp = NDArrayFactory::create<int>({2, 2, 2});
|
||||
|
||||
nd4j::ops::broadcast_dynamic_shape op;
|
||||
|
||||
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32);
|
||||
auto res = op.execute({&x, &y}, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||
|
@ -1107,7 +1105,6 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_2) {
|
|||
|
||||
auto y = NDArrayFactory::create<Nd4jLong>({2, 1, 2});
|
||||
|
||||
// ------------------------------------
|
||||
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 2});
|
||||
|
||||
nd4j::ops::broadcast_dynamic_shape op;
|
||||
|
@ -1122,17 +1119,15 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_2) {
|
|||
/////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_3) {
|
||||
|
||||
auto x = NDArrayFactory::create<Nd4jLong>( {2, 2, 2} );
|
||||
auto x = NDArrayFactory::create<int>( {2, 2, 2} );
|
||||
|
||||
auto y = NDArrayFactory::create<Nd4jLong>({ 2, 1});
|
||||
auto y = NDArrayFactory::create<int>({2, 1});
|
||||
|
||||
// ------------------------------------
|
||||
|
||||
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 2});
|
||||
auto exp = NDArrayFactory::create<int>({2, 2, 2});
|
||||
|
||||
nd4j::ops::broadcast_dynamic_shape op;
|
||||
|
||||
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
|
||||
auto res = op.execute({&x, &y}, {}, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||
|
@ -1145,9 +1140,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_4) {
|
|||
|
||||
auto x = NDArrayFactory::create<Nd4jLong>( {2, 1} );
|
||||
|
||||
auto y = NDArrayFactory::create<Nd4jLong>('c', {1}, { 4,});
|
||||
|
||||
// ------------------------------------
|
||||
auto y = NDArrayFactory::create<Nd4jLong>('c', {1}, {4});
|
||||
|
||||
auto exp = NDArrayFactory::create<Nd4jLong>({2, 4});
|
||||
|
||||
|
@ -1161,49 +1154,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_4) {
|
|||
|
||||
delete res;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_5) {
|
||||
|
||||
auto x = NDArrayFactory::create<Nd4jLong>({2, 2, 2});
|
||||
|
||||
auto y = NDArrayFactory::create<Nd4jLong>({2, 2});
|
||||
|
||||
// ------------------------------------
|
||||
|
||||
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 2});
|
||||
|
||||
nd4j::ops::broadcast_dynamic_shape op;
|
||||
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||
// res->at(0)->printIndexedBuffer("Output");
|
||||
// exp.printIndexedBuffer("Expect");
|
||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||
|
||||
delete res;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_5) {
|
||||
|
||||
auto x = NDArrayFactory::create<Nd4jLong>({2, 1, 2});
|
||||
|
||||
auto y = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
|
||||
|
||||
// ------------------------------------
|
||||
|
||||
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
|
||||
|
||||
nd4j::ops::broadcast_dynamic_shape op;
|
||||
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||
// res->at(0)->printIndexedBuffer("Output SGO 5");
|
||||
// exp.printIndexedBuffer("Expect");
|
||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||
|
||||
delete res;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_6) {
|
||||
|
||||
|
@ -1211,16 +1162,12 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_6) {
|
|||
|
||||
auto y = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
|
||||
|
||||
// ------------------------------------
|
||||
|
||||
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
|
||||
|
||||
nd4j::ops::broadcast_dynamic_shape op;
|
||||
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||
// res->at(0)->printIndexedBuffer("Output SGO 6");
|
||||
// exp.printIndexedBuffer("Expect");
|
||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||
|
||||
delete res;
|
||||
|
@ -1233,16 +1180,12 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_7) {
|
|||
|
||||
auto y = NDArrayFactory::create<Nd4jLong>({2, 4, 1});
|
||||
|
||||
// ------------------------------------
|
||||
|
||||
auto exp = NDArrayFactory::create<Nd4jLong>({2, 4, 3});
|
||||
|
||||
nd4j::ops::broadcast_dynamic_shape op;
|
||||
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||
// res->at(0)->printIndexedBuffer("Output SGO 7");
|
||||
// exp.printIndexedBuffer("Expect");
|
||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||
|
||||
delete res;
|
||||
|
@ -1255,19 +1198,15 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_8) {
|
|||
|
||||
auto y = NDArrayFactory::create<int>('c', {1}, {4});
|
||||
|
||||
// ------------------------------------
|
||||
auto z = NDArrayFactory::create<int>('c', {1});
|
||||
|
||||
auto exp = NDArrayFactory::create<int>('c', {1}, {4});
|
||||
|
||||
nd4j::ops::broadcast_dynamic_shape op;
|
||||
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32);
|
||||
auto status = op.execute({&x, &y}, {&z}, {}, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||
// res->at(0)->printIndexedBuffer("Output SGO 8");
|
||||
// exp.printIndexedBuffer("Expect");
|
||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||
|
||||
delete res;
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1277,19 +1216,16 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_9) {
|
|||
|
||||
auto y = NDArrayFactory::create<int>('c', {1}, {1});
|
||||
|
||||
// ------------------------------------
|
||||
auto z = NDArrayFactory::create<Nd4jLong>('c', {2});
|
||||
|
||||
auto exp = NDArrayFactory::create<int>('c', {2}, {2,2});
|
||||
auto exp = NDArrayFactory::create<Nd4jLong>('c', {2}, {2,2});
|
||||
|
||||
nd4j::ops::broadcast_dynamic_shape op;
|
||||
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32);
|
||||
auto status = op.execute({&x, &y}, {&z}, {}, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||
res->at(0)->printIndexedBuffer("Output SGO 9");
|
||||
exp.printIndexedBuffer("Expect9");
|
||||
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
// ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
Loading…
Reference in New Issue