Oleh true broadcast opt (#234)
* libnd4j trueBroadcast special case Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j fix trueBroadcast special case Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j special case of TrueBroadcastHelper Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j trueBroadCast special case and test * libnd4j minor changes sync with master * libnd4j changes to TrueBroadcastHelper.hpp per require Signed-off-by: Oleg <oleg.semeniv@gmail.com>master
parent
b941186302
commit
11cb561045
|
@ -1,4 +1,4 @@
|
||||||
{
|
{
|
||||||
"configurations": [
|
"configurations": [
|
||||||
{
|
{
|
||||||
"name": "x64-Debug",
|
"name": "x64-Debug",
|
||||||
|
|
|
@ -32,9 +32,10 @@ template <typename X, typename Y, typename Z>
|
||||||
template<typename OpType>
|
template<typename OpType>
|
||||||
void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||||
|
|
||||||
|
|
||||||
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
||||||
const Y* y = reinterpret_cast<Y*>(yArr.getBuffer());
|
const Y* y = reinterpret_cast<Y*>(yArr.getBuffer());
|
||||||
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
|
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
|
||||||
|
|
||||||
const auto xShapeInfo = xArr.getShapeInfo();
|
const auto xShapeInfo = xArr.getShapeInfo();
|
||||||
const auto yShapeInfo = yArr.getShapeInfo();
|
const auto yShapeInfo = yArr.getShapeInfo();
|
||||||
|
@ -44,8 +45,26 @@ void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr
|
||||||
const int yRank = yArr.rankOf();
|
const int yRank = yArr.rankOf();
|
||||||
const int zRank = zArr.rankOf();
|
const int zRank = zArr.rankOf();
|
||||||
|
|
||||||
const Nd4jLong zLen = zArr.lengthOf();
|
bool bSpecialCase = (1 == xArr.ews() && 'c' == xArr.ordering() && 1 == yRank &&
|
||||||
|
1 == yArr.ews() && 'c' == yArr.ordering() &&
|
||||||
|
1 == zArr.ews() && 'c' == zArr.ordering());
|
||||||
|
|
||||||
|
if (bSpecialCase) {
|
||||||
|
auto yLen = (uint32_t)yArr.lengthOf();
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
for (uint32_t i = start; i < stop; i++) {
|
||||||
|
auto rZ = z + (i * yLen);
|
||||||
|
auto v = x[i];
|
||||||
|
for (uint32_t j = 0; j < yLen; j++) {
|
||||||
|
rZ[j] = OpType::op(v, y[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
samediff::Threads::parallel_tad(func, 0, xArr.lengthOf());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const Nd4jLong zLen = zArr.lengthOf();
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||||
|
|
||||||
|
|
|
@ -532,3 +532,24 @@ TEST_F(DeclarableOpsTests14, repeat_5) {
|
||||||
|
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
/////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_SpecialCaseTest) {
|
||||||
|
|
||||||
|
auto y = NDArray('c', { 3 }, nd4j::DataType::FLOAT32);
|
||||||
|
auto x = NDArray('c', { 5, 2, 1 }, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
auto e = NDArray('c', { 5, 2, 3 }, { 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 11., 11., 11. }, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
y.assign(1.0);
|
||||||
|
x.linspace(1.0);
|
||||||
|
|
||||||
|
nd4j::ops::add op;
|
||||||
|
auto result = op.evaluate({ &x, &y });
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto res = *result->at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, res);
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue