Oleh true broadcast opt (#234)
* libnd4j trueBroadcast special case Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j fix trueBroadcast special case Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j special case of TrueBroadcastHelper Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j trueBroadCast special case and test * libnd4j minor changes sync with master * libnd4j changes to TrueBroadcastHelper.hpp per require Signed-off-by: Oleg <oleg.semeniv@gmail.com>master
parent
b941186302
commit
11cb561045
|
@ -1,4 +1,4 @@
|
|||
{
|
||||
{
|
||||
"configurations": [
|
||||
{
|
||||
"name": "x64-Debug",
|
||||
|
|
|
@ -32,6 +32,7 @@ template <typename X, typename Y, typename Z>
|
|||
template<typename OpType>
|
||||
void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
|
||||
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
||||
const Y* y = reinterpret_cast<Y*>(yArr.getBuffer());
|
||||
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
|
||||
|
@ -44,8 +45,26 @@ void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr
|
|||
const int yRank = yArr.rankOf();
|
||||
const int zRank = zArr.rankOf();
|
||||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
bool bSpecialCase = (1 == xArr.ews() && 'c' == xArr.ordering() && 1 == yRank &&
|
||||
1 == yArr.ews() && 'c' == yArr.ordering() &&
|
||||
1 == zArr.ews() && 'c' == zArr.ordering());
|
||||
|
||||
if (bSpecialCase) {
|
||||
auto yLen = (uint32_t)yArr.lengthOf();
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (uint32_t i = start; i < stop; i++) {
|
||||
auto rZ = z + (i * yLen);
|
||||
auto v = x[i];
|
||||
for (uint32_t j = 0; j < yLen; j++) {
|
||||
rZ[j] = OpType::op(v, y[j]);
|
||||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(func, 0, xArr.lengthOf());
|
||||
return;
|
||||
}
|
||||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
|
|
|
@ -532,3 +532,24 @@ TEST_F(DeclarableOpsTests14, repeat_5) {
|
|||
|
||||
delete result;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_SpecialCaseTest) {
|
||||
|
||||
auto y = NDArray('c', { 3 }, nd4j::DataType::FLOAT32);
|
||||
auto x = NDArray('c', { 5, 2, 1 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
auto e = NDArray('c', { 5, 2, 3 }, { 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 11., 11., 11. }, nd4j::DataType::FLOAT32);
|
||||
|
||||
y.assign(1.0);
|
||||
x.linspace(1.0);
|
||||
|
||||
nd4j::ops::add op;
|
||||
auto result = op.evaluate({ &x, &y });
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto res = *result->at(0);
|
||||
|
||||
ASSERT_EQ(e, res);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue