diff --git a/libnd4j/include/ops/declarable/helpers/cpu/gather.cpp b/libnd4j/include/ops/declarable/helpers/cpu/gather.cpp
index 3fb7c290d..09c8c09ea 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/gather.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/gather.cpp
@@ -21,6 +21,8 @@
 #include <ops/declarable/helpers/gather.h>
 #include <numeric>
 #include <execution/Threads.h>
+#include <ShapeUtils.h>
+#include <ConstantTadHelper.h>
 
 namespace nd4j {
 namespace ops {
@@ -36,7 +38,7 @@ void gather(nd4j::LaunchContext * context, const NDArray* input, const NDArray*
 
     const int numOfIntArgs = intArgs.size();
 
-    if (indices != nullptr) {
+    if (indices != nullptr) {  // first case: indices consist of only one scalar
 
         if(indices->isScalar()) {
@@ -46,7 +48,7 @@ void gather(nd4j::LaunchContext * context, const NDArray* input, const NDArray*
                auto idx = indices->e<Nd4jLong>(0);
                auto scalarNDArray = input->e(idx);
                output->assign(scalarNDArray);
-            }
+            }
             else {
                 NDArray inSubArr = (*input)(indices->e<Nd4jLong>(0), {axis});
                 output->assign(inSubArr);
@@ -54,41 +56,122 @@ void gather(nd4j::LaunchContext * context, const NDArray* input, const NDArray*
         }
         else {
 
-            std::vector<int> dimsOut(indices->rankOf());
-            std::iota(dimsOut.begin(), dimsOut.end(), axis);   // fill with axis, axis+1, ... axis+indices->rankOf()-1
-            const Nd4jLong numOfSubArrs = indices->lengthOf();
+            if(input->rankOf() == 1 && output->rankOf() == 1) {
 
-            auto func = PRAGMA_THREADS_FOR {
-                for (auto i = start; i < stop; i += increment) {
-                    NDArray subArrOut = (*output)(i, dimsOut);
-                    NDArray subArrIn = (*input)(indices->e<Nd4jLong>(i), {axis});
-                    subArrOut.assign(subArrIn);
+                auto func = PRAGMA_THREADS_FOR {
+                    for (auto i = start; i < stop; i += increment)
+                        output->p(i, input->e(indices->e<Nd4jLong>(i)));
+                };
+
+                samediff::Threads::parallel_for(func, 0, output->lengthOf());
+
+            }
+            else {
+
+                std::vector<int> dimsOut;
+                for (int i = 0; i < axis; ++i)
+                    dimsOut.push_back(i);
+                for (int i = axis+indices->rankOf(); i < output->rankOf(); ++i)
+                    dimsOut.push_back(i);
+
+                std::vector<int> dimsIn = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis});
+
+                const Nd4jLong numOfSubArrs = indices->lengthOf();
+
+                auto inTadPack  = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimsIn);
+                auto outTadPack = ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimsOut);
+
+                Nd4jLong* inTadShapeInfo  = inTadPack.primaryShapeInfo();
+                Nd4jLong* outTadShapeInfo = outTadPack.primaryShapeInfo();
+
+                if (shape::order(inTadShapeInfo) == shape::order(outTadShapeInfo) && shape::order(inTadShapeInfo) == 'c' && input->dataType() == output->dataType() && shape::elementWiseStride(inTadShapeInfo) == 1 && shape::elementWiseStride(outTadShapeInfo) == 1) {
+
+                    auto func = PRAGMA_THREADS_FOR {
+
+                        for (auto i = start; i < stop; i += increment) {
+
+                            void* inBuff  = input->bufferWithOffset(inTadPack.primaryOffsets()[indices->e<Nd4jLong>(i)]);
+                            void* outBuff = output->bufferWithOffset(outTadPack.primaryOffsets()[i]);
+
+                            memcpy(outBuff, inBuff, shape::length(inTadShapeInfo) * input->sizeOfT());
+                        }
+                    };
+                    samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
                 }
-            };
+                else {
+                    auto func = PRAGMA_THREADS_FOR {
+                        for (auto i = start; i < stop; i += increment) {
 
-            samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
+                            void* inBuff  = input->bufferWithOffset(inTadPack.primaryOffsets()[indices->e<Nd4jLong>(i)]);
+                            void* outBuff = output->bufferWithOffset(outTadPack.primaryOffsets()[i]);
+
+                            NativeOpExecutioner::execTransformAny(input->getContext(), transform::Assign,
+                                                                  inBuff, inTadShapeInfo, nullptr/*input specialBuffer*/, nullptr/*input specialShapeInfo*/,
+                                                                  outBuff, outTadShapeInfo, nullptr/*output specialBuffer*/, nullptr/*output specialShapeInfo*/,
+                                                                  nullptr, nullptr, nullptr, false/*allowParallelism*/);
+                        }
+                    };
+
+                    samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
+                }
+            }
         }
-    }
+    }
     else {
 
-
+        // we only allow scalar/vector case here
         if (numOfIntArgs == 2) { // scalar case
+            output->assign((*input)(intArgs[1], {axis}));
         }
         else { // vector case
+            const Nd4jLong numOfSubArrs = intArgs.size() - 1;
 
-            auto func = PRAGMA_THREADS_FOR {
-                for (auto i = start; i < stop; i += increment) {
-                    NDArray subArrOut = (*output)(i, {axis});
-                    NDArray subArrIn = (*input)(intArgs[i + 1], {axis});
-                    subArrOut.assign(subArrIn);
-                }
-            };
+            std::vector<int> dims = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis});
+
+            auto inTadPack  = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dims);
+            auto outTadPack = ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dims);
+
+            Nd4jLong* inTadShapeInfo  = inTadPack.primaryShapeInfo();
+            Nd4jLong* outTadShapeInfo = outTadPack.primaryShapeInfo();
+
+            if (shape::order(inTadShapeInfo) == shape::order(outTadShapeInfo) && shape::order(inTadShapeInfo) == 'c' && input->dataType() == output->dataType() && shape::elementWiseStride(inTadShapeInfo) == 1 && shape::elementWiseStride(outTadShapeInfo) == 1) {
+
+                auto func = PRAGMA_THREADS_FOR {
+
+                    for (auto i = start; i < stop; i += increment) {
+
+                        void* inBuff  = input->bufferWithOffset(inTadPack.primaryOffsets()[intArgs[i + 1]]);
+                        void* outBuff = output->bufferWithOffset(outTadPack.primaryOffsets()[i]);
+
+                        std::memcpy(outBuff, inBuff, shape::length(inTadShapeInfo) * input->sizeOfT());
+                    }
+                };
+                samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
+
+            }
+            else {
+
+                auto func = PRAGMA_THREADS_FOR {
+
+                    for (auto i = start; i < stop; i += increment) {
+
+                        void* inBuff  = input->bufferWithOffset(inTadPack.primaryOffsets()[intArgs[i + 1]]);
+                        void* outBuff = output->bufferWithOffset(outTadPack.primaryOffsets()[i]);
+
+                        NativeOpExecutioner::execTransformAny(input->getContext(), transform::Assign,
+                                                              inBuff, inTadShapeInfo, nullptr/*input specialBuffer*/, nullptr/*input specialShapeInfo*/,
+                                                              outBuff, outTadShapeInfo, nullptr/*output specialBuffer*/, nullptr/*output specialShapeInfo*/,
+                                                              nullptr, nullptr, nullptr, false/*allowParallelism*/);
+
+                    }
+                };
+                samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
+            }
 
-            samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
         }
-    }
+    }
 }
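For illustration: the gather change above replaces per-index NDArray sub-array views with precomputed TAD (tensor-along-dimension) offsets, and takes a raw memcpy per index whenever input and output TADs are both c-ordered, contiguous (element-wise stride 1), and of the same data type; all other layouts fall through to a per-TAD transform::Assign dispatch. A minimal standalone sketch of that fast path, assuming a hypothetical gatherRows helper and plain float data (this is not libnd4j code):

// Sketch: gather along axis 0 of a c-ordered 2-D array, one memcpy per index,
// mirroring the diff's fast-path precondition (same dtype, ews == 1, 'c' order).
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

static void gatherRows(const float* in, float* out,
                       const std::vector<int64_t>& indices, int64_t rowLen) {
    for (size_t i = 0; i < indices.size(); ++i)   // one contiguous TAD per index
        std::memcpy(out + i * rowLen, in + indices[i] * rowLen,
                    rowLen * sizeof(float));
}

int main() {
    const float in[6] = {0, 1, 2, 3, 4, 5};       // shape [3,2], rows {0,1},{2,3},{4,5}
    const std::vector<int64_t> idx = {2, 0, 2};
    float out[6];                                 // shape [3,2]
    gatherRows(in, out, idx, 2);
    for (float v : out) std::printf("%g ", v);    // prints: 4 5 0 1 4 5
    std::printf("\n");
    return 0;
}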
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp
index f47d08b7a..53d18e3cd 100644
--- a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp
+++ b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp
@@ -279,7 +279,7 @@ PLATFORM_CHECK(matmul, ENGINE_CPU) {
 
     const DataType zType = z->dataType();
 
-    return block.isUseMKLDNN() &&
+    return block.isUseMKLDNN() && x->rankOf() < 3 &&
            (
             (xType==DataType::FLOAT32 && yType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
             (xType==DataType::HALF && yType==DataType::HALF && zType==DataType::FLOAT32) ||
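The matmul hunk (truncated above) only tightens the MKL-DNN eligibility check: the added x->rankOf() < 3 keeps batched (rank >= 3) inputs on the generic implementation. A standalone sketch of the predicate's shape, using hypothetical names and only the two dtype triples visible in the excerpt (the remaining triples are cut off, so they are omitted here):

// Mirrors the tightened PLATFORM_CHECK: MKL-DNN matmul is claimed only for
// rank-1/rank-2 x and the supported (x, y, z) dtype triples.
#include <cstdio>

enum class DataType { FLOAT32, HALF };

static bool mkldnnMatmulApplicable(bool useMkldnn, int xRank,
                                   DataType x, DataType y, DataType z) {
    return useMkldnn && xRank < 3 &&
           ((x == DataType::FLOAT32 && y == DataType::FLOAT32 && z == DataType::FLOAT32) ||
            (x == DataType::HALF    && y == DataType::HALF    && z == DataType::FLOAT32));
}

int main() {
    // rank-3 (batched) matmul now falls back to the generic kernel:
    std::printf("%d\n", mkldnnMatmulApplicable(true, 3, DataType::FLOAT32, DataType::FLOAT32, DataType::FLOAT32)); // 0
    std::printf("%d\n", mkldnnMatmulApplicable(true, 2, DataType::FLOAT32, DataType::FLOAT32, DataType::FLOAT32)); // 1
    return 0;
}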