cavis/libnd4j/include/ops/declarable/helpers/cpu/indexReductions.hpp

903 lines
37 KiB
C++
Raw Normal View History

2021-02-01 13:31:45 +01:00
/* ******************************************************************************
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
2021-02-01 13:31:45 +01:00
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author AbdelRauf
//
#include <type_traits>
#include <cmath>
#include <stdexcept>
#include <memory>
#include <execution/Threads.h>
#include <execution/ThreadPool.h>
#include <helpers/LoopsCoordsHelper.h>
#include <ops/declarable/helpers/reductions.h>
// Debug tracing toggle. With `#if 1`, LOG_CALLS(X) expands to nothing.
// Switch to `#if 0` to print the enclosing function signature and the integer tag X
// whenever a LOG_CALLS site is reached.
#if 1
#define LOG_CALLS(X)
#else
#define LOG_CALLS(X) nd4j_printf("___%s_________%d+\n", __PRETTY_FUNCTION__, X);
#endif
namespace sd {
namespace ops {
namespace helpers {
// Element-count cut-off used by the drivers below: when the amount of work is
// under (or, for the non-scalar driver, at most) this value, everything runs on one thread.
constexpr int threadingThreshold{4096};
/**
 * Index reduction over a contiguous rank-1 run.
 *
 * Seeds the running state with buffer[0] at position 0, then feeds every
 * element (including buffer[0] again) to ReductionOp::update, which decides
 * whether the candidate replaces the running value/position.
 *
 * @param buffer     first element of the run; elements 0..loopCount-1 are read
 *                   (buffer[0] is read even when loopCount is 0)
 * @param current    out: running value maintained by ReductionOp
 * @param argCurrent out: winning position, relative to `buffer`
 * @param loopCount  number of elements to scan
 */
template<typename X, typename Z, typename ReductionOp>
FORCEINLINE void indexInnerReductionRank1(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount)
{
    argCurrent = 0;
    current = buffer[0];
    LOG_CALLS(0)
    for (Z j = 0; j < loopCount; j++) {
        ReductionOp::update(current, argCurrent, buffer[j], j);
    }
}
/**
 * Strided twin of the contiguous rank-1 reduction: scans `loopCount` elements
 * that lie `inner_stride` apart in memory, starting at buffer[0].
 *
 * Positions handed to ReductionOp (and reported in argCurrent) are logical
 * element positions 0..loopCount-1, not memory offsets.
 */
template<typename X, typename Z, typename ReductionOp>
FORCEINLINE void indexInnerReductionRank1(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount, const Nd4jLong& inner_stride)
{
    current = buffer[0];
    argCurrent = 0;
    LOG_CALLS(0)
    Nd4jLong memOffset = 0;  // memory offset of the element at logical position `pos`
    for (Z pos = 0; pos < loopCount; ++pos) {
        ReductionOp::update(current, argCurrent, buffer[memOffset], pos);
        memOffset += inner_stride;
    }
}
/**
 * Index reduction over a compile-time-rank region whose innermost dimension
 * is read with unit stride.
 *
 * The innermost dimension (the last one when LastIndexFaster, otherwise the
 * first) is scanned as `innerLoopCount` consecutive elements. The remaining
 * constRank-1 dimensions are walked with sd::init_coords / sd::inc_coords for
 * `outerLoopCount` steps. Positions passed to ReductionOp are flat logical
 * indices: outerStep * innerLoopCount + innerPos.
 */
template<typename X, typename Z, typename ReductionOp, size_t constRank, bool LastIndexFaster = true>
FORCEINLINE void indexInnerReductionConstRank(const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong outerLoopCount, const Nd4jLong& innerLoopCount)
{
    // drop the innermost dimension: it sits at the end for LastIndexFaster, at the front otherwise
    constexpr size_t skipFront = LastIndexFaster ? 0 : 1;
    constexpr size_t outerRank = constRank - 1;
    sd::CoordsState<outerRank - 1> coordState;
    size_t rowOffset = sd::init_coords<outerRank, 0, LastIndexFaster>(coordState, 0, bases + skipFront, strides + skipFront);
    argCurrent = 0;
    current = buffer[rowOffset];
    LOG_CALLS(0)
    Z rowBase = 0;  // flat index of the first element of the current row
    for (Z row = 0; row < outerLoopCount; row++) {
        const X* rowPtr = buffer + rowOffset;
        for (Z col = 0; col < innerLoopCount; col++) {
            ReductionOp::update(current, argCurrent, rowPtr[col], rowBase + col);
        }
        rowOffset = sd::inc_coords<outerRank, 0, LastIndexFaster>(coordState, rowOffset);
        rowBase += innerLoopCount;
    }
}
/**
 * Strided twin of the compile-time-rank reduction: identical outer walk
 * (sd::init_coords / sd::inc_coords over constRank-1 dimensions), but the
 * innermost dimension's elements are `inner_stride` apart in memory.
 * Positions passed to ReductionOp are flat logical indices
 * outerStep * innerLoopCount + innerPos.
 */
template<typename X, typename Z, typename ReductionOp, size_t constRank, bool LastIndexFaster = true>
FORCEINLINE void indexInnerReductionConstRank(const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong outerLoopCount, const Nd4jLong& innerLoopCount, const Nd4jLong& inner_stride)
{
    // drop the innermost dimension: it sits at the end for LastIndexFaster, at the front otherwise
    constexpr size_t skipFront = LastIndexFaster ? 0 : 1;
    constexpr size_t outerRank = constRank - 1;
    sd::CoordsState<outerRank - 1> coordState;
    size_t rowOffset = sd::init_coords<outerRank, 0, LastIndexFaster>(coordState, 0, bases + skipFront, strides + skipFront);
    argCurrent = 0;
    current = buffer[rowOffset];
    LOG_CALLS(0)
    Z rowBase = 0;  // flat index of the first element of the current row
    for (Z row = 0; row < outerLoopCount; row++) {
        const X* rowPtr = buffer + rowOffset;
        Nd4jLong memOffset = 0;
        for (Z col = 0; col < innerLoopCount; col++) {
            ReductionOp::update(current, argCurrent, rowPtr[memOffset], rowBase + col);
            memOffset += inner_stride;
        }
        rowOffset = sd::inc_coords<outerRank, 0, LastIndexFaster>(coordState, rowOffset);
        rowBase += innerLoopCount;
    }
}
/**
 * Index reduction over a runtime-rank region whose last dimension is read with
 * unit stride.
 *
 * Handles outer rows [outerLoopStart, outerLoopStop); each row holds
 * `innerLoopCount` elements. Row coordinates are advanced with
 * inc_coords<true> (C order). The LastIndexFaster template parameter is
 * currently not consulted. Positions handed to ReductionOp are flat indices
 * measured from the start of the whole region (row * innerLoopCount + col), so
 * results of separate [start, stop) ranges can be compared directly.
 */
template<typename X, typename Z, typename ReductionOp, bool LastIndexFaster = true>
FORCEINLINE void indexInnerReduction(const int& rank, const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong& outerLoopStart, const Nd4jLong& outerLoopStop, const Nd4jLong& innerLoopCount)
{
    Nd4jLong coords[MAX_RANK] = {};
    Nd4jLong* coordPtr = coords;
    size_t rowOffset = 0;
    if (outerLoopStart > 0) {
        // place the coordinate cursor on the first row owned by this call
        sd::index2coords_C(outerLoopStart, rank - 1, bases, coordPtr);
        rowOffset = sd::offset_from_coords(strides, coordPtr, rank);
    }
    Z rowBase = outerLoopStart * innerLoopCount;
    argCurrent = rowBase;
    current = buffer[rowOffset];
    LOG_CALLS(0)
    const Nd4jLong rowCount = outerLoopStop - outerLoopStart;
    for (Z row = 0; row < rowCount; row++) {
        const X* rowPtr = buffer + rowOffset;
        for (Z col = 0; col < innerLoopCount; col++) {
            ReductionOp::update(current, argCurrent, rowPtr[col], rowBase + col);
        }
        rowOffset = inc_coords<true>(bases, strides, coordPtr, rowOffset, rank, 1);
        rowBase += innerLoopCount;
    }
}
/**
 * Strided twin of the runtime-rank reduction: same row walk over
 * [outerLoopStart, outerLoopStop) with inc_coords<true>, but elements within a
 * row are `inner_stride` apart in memory. Positions handed to ReductionOp are
 * flat indices from the start of the whole region (row * innerLoopCount + col).
 * The LastIndexFaster template parameter is currently not consulted.
 */
template<typename X, typename Z, typename ReductionOp, bool LastIndexFaster = true>
FORCEINLINE void indexInnerReduction(const int& rank, const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong& outerLoopStart, const Nd4jLong& outerLoopStop, const Nd4jLong& innerLoopCount, const Nd4jLong& inner_stride)
{
    Nd4jLong coords[MAX_RANK] = {};
    Nd4jLong* coordPtr = coords;
    size_t rowOffset = 0;
    if (outerLoopStart > 0) {
        // place the coordinate cursor on the first row owned by this call
        sd::index2coords_C(outerLoopStart, rank - 1, bases, coordPtr);
        rowOffset = sd::offset_from_coords(strides, coordPtr, rank);
    }
    Z rowBase = outerLoopStart * innerLoopCount;
    argCurrent = rowBase;
    current = buffer[rowOffset];
    LOG_CALLS(0)
    const Nd4jLong rowCount = outerLoopStop - outerLoopStart;
    for (Z row = 0; row < rowCount; row++) {
        const X* rowPtr = buffer + rowOffset;
        for (Z col = 0; col < innerLoopCount; col++) {
            ReductionOp::update(current, argCurrent, rowPtr[col * inner_stride], rowBase + col);
        }
        rowOffset = inc_coords<true>(bases, strides, coordPtr, rowOffset, rank, 1);
        rowBase += innerLoopCount;
    }
}
/**
 * Contiguous rank-1 index reduction using four interleaved running states.
 *
 * The run is cut into four chunks of loopCount / 4 elements. The
 * loopCount % 4 leftover elements are appended to the fourth chunk. All four
 * chunks are reduced in a single loop, then the chunk winners are folded into
 * chunk 0's state in chunk order (1, 2, 3) after shifting their chunk-local
 * positions to run positions. With the strict comparisons of the Index* ops in
 * this file, a tie therefore keeps the earliest position.
 */
template<typename X, typename Z, typename ReductionOp>
FORCEINLINE void indexInnerReductionRank1Block4WithMerge(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount)
{
    current = buffer[0];
    argCurrent = 0;
    LOG_CALLS(0)
    const Nd4jLong chunk = loopCount / 4;
    const Nd4jLong lastChunkEnd = chunk + (loopCount & 3);
    const X* chunk1 = buffer + chunk;
    const X* chunk2 = chunk1 + chunk;
    const X* chunk3 = chunk2 + chunk;
    X best1 = chunk1[0];
    X best2 = chunk2[0];
    X best3 = chunk3[0];
    Z bestIdx1 = 0;
    Z bestIdx2 = 0;
    Z bestIdx3 = 0;
    for (Z pos = 0; pos < chunk; pos++) {
        ReductionOp::update(current, argCurrent, buffer[pos], pos);
        ReductionOp::update(best1, bestIdx1, chunk1[pos], pos);
        ReductionOp::update(best2, bestIdx2, chunk2[pos], pos);
        ReductionOp::update(best3, bestIdx3, chunk3[pos], pos);
    }
    // the fourth chunk also absorbs the loopCount % 4 leftover elements
    for (Z pos = chunk; pos < lastChunkEnd; pos++) {
        ReductionOp::update(best3, bestIdx3, chunk3[pos], pos);
    }
    // convert chunk-local positions to run positions, then fold in chunk order
    bestIdx1 += chunk;
    bestIdx2 += 2 * chunk;
    bestIdx3 += 3 * chunk;
    ReductionOp::update(current, argCurrent, best1, bestIdx1);
    ReductionOp::update(current, argCurrent, best2, bestIdx2);
    ReductionOp::update(current, argCurrent, best3, bestIdx3);
}
/**
 * Strided twin of the four-chunk rank-1 reduction: elements are `inner_stride`
 * apart in memory. The four chunks hold loopCount / 4 logical elements each,
 * and the loopCount % 4 leftovers are appended to the fourth chunk. Chunk
 * winners are folded in chunk order after shifting their positions by the chunk
 * start. Reported positions are logical (0..loopCount-1), not memory offsets.
 */
template<typename X, typename Z, typename ReductionOp>
FORCEINLINE void indexInnerReductionRank1Block4WithMerge(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount, const Nd4jLong& inner_stride)
{
    current = buffer[0];
    argCurrent = 0;
    LOG_CALLS(0)
    const Nd4jLong chunk = loopCount / 4;
    const Nd4jLong lastChunkEnd = chunk + (loopCount & 3);
    const Nd4jLong chunkSpan = inner_stride * chunk;  // memory distance between chunk starts
    const X* chunk1 = buffer + chunkSpan;
    const X* chunk2 = chunk1 + chunkSpan;
    const X* chunk3 = chunk2 + chunkSpan;
    X best1 = chunk1[0];
    X best2 = chunk2[0];
    X best3 = chunk3[0];
    Z bestIdx1 = 0;
    Z bestIdx2 = 0;
    Z bestIdx3 = 0;
    Nd4jLong memOffset = 0;
    for (Z pos = 0; pos < chunk; pos++) {
        ReductionOp::update(current, argCurrent, buffer[memOffset], pos);
        ReductionOp::update(best1, bestIdx1, chunk1[memOffset], pos);
        ReductionOp::update(best2, bestIdx2, chunk2[memOffset], pos);
        ReductionOp::update(best3, bestIdx3, chunk3[memOffset], pos);
        memOffset += inner_stride;
    }
    // the fourth chunk also absorbs the loopCount % 4 leftover elements
    for (Z pos = chunk; pos < lastChunkEnd; pos++) {
        ReductionOp::update(best3, bestIdx3, chunk3[memOffset], pos);
        memOffset += inner_stride;
    }
    // convert chunk-local positions to run positions, then fold in chunk order
    bestIdx1 += chunk;
    bestIdx2 += 2 * chunk;
    bestIdx3 += 3 * chunk;
    ReductionOp::update(current, argCurrent, best1, bestIdx1);
    ReductionOp::update(current, argCurrent, best2, bestIdx2);
    ReductionOp::update(current, argCurrent, best3, bestIdx3);
}
/**
 * Reduces four independent contiguous rank-1 runs of equal length in one loop.
 * Run k starts at buffer / buffer1 / buffer2 / buffer3, and its winning
 * position (0..loopCount-1) is stored through output / output1 / output2 / output3.
 * Each run is seeded with its element 0.
 */
template<typename X, typename Z, typename ReductionOp>
FORCEINLINE void indexInnerReductionRank1Block4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3, Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong& loopCount)
{
    LOG_CALLS(0)
    X best0 = buffer[0];
    X best1 = buffer1[0];
    X best2 = buffer2[0];
    X best3 = buffer3[0];
    Z idx0 = 0;
    Z idx1 = 0;
    Z idx2 = 0;
    Z idx3 = 0;
    for (Z pos = 0; pos < loopCount; ++pos) {
        ReductionOp::update(best0, idx0, buffer[pos], pos);
        ReductionOp::update(best1, idx1, buffer1[pos], pos);
        ReductionOp::update(best2, idx2, buffer2[pos], pos);
        ReductionOp::update(best3, idx3, buffer3[pos], pos);
    }
    *output = idx0;
    *output1 = idx1;
    *output2 = idx2;
    *output3 = idx3;
}
/**
 * Strided twin of the four-run rank-1 reduction. All four runs share the same
 * `inner_stride` and length. Winning logical positions (0..loopCount-1) are
 * stored through output / output1 / output2 / output3.
 */
template<typename X, typename Z, typename ReductionOp>
FORCEINLINE void indexInnerReductionRank1Block4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3, Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong& loopCount, const Nd4jLong& inner_stride)
{
    LOG_CALLS(0)
    X best0 = buffer[0];
    X best1 = buffer1[0];
    X best2 = buffer2[0];
    X best3 = buffer3[0];
    Z idx0 = 0;
    Z idx1 = 0;
    Z idx2 = 0;
    Z idx3 = 0;
    Nd4jLong memOffset = 0;  // shared memory offset of logical position `pos` in every run
    for (Z pos = 0; pos < loopCount; ++pos) {
        ReductionOp::update(best0, idx0, buffer[memOffset], pos);
        ReductionOp::update(best1, idx1, buffer1[memOffset], pos);
        ReductionOp::update(best2, idx2, buffer2[memOffset], pos);
        ReductionOp::update(best3, idx3, buffer3[memOffset], pos);
        memOffset += inner_stride;
    }
    *output = idx0;
    *output1 = idx1;
    *output2 = idx2;
    *output3 = idx3;
}
/**
 * Four-region version of the compile-time-rank reduction with unit inner stride.
 *
 * The four regions start at buffer / buffer1 / buffer2 / buffer3 and share the
 * same geometry (bases/strides), so a single sd::init_coords / sd::inc_coords
 * walk over the outer constRank-1 dimensions drives all of them. Each region is
 * seeded with its element 0. Its winning flat index
 * (outerStep * innerLoopCount + innerPos) is stored through
 * output / output1 / output2 / output3.
 */
template<typename X, typename Z, typename ReductionOp, size_t constRank, bool LastIndexFaster = true>
FORCEINLINE void indexInnerReductionConstRankBlock4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3,
    Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong* bases, const Nd4jLong* strides,
    const Nd4jLong& outerLoopCount, const Nd4jLong& innerLoopCount)
{
    LOG_CALLS(0)
    // drop the innermost dimension: it sits at the end for LastIndexFaster, at the front otherwise
    constexpr size_t skipFront = LastIndexFaster ? 0 : 1;
    constexpr size_t outerRank = constRank - 1;
    sd::CoordsState<outerRank - 1> coordState;
    size_t rowOffset = sd::init_coords<outerRank, 0, LastIndexFaster>(coordState, 0, bases + skipFront, strides + skipFront);
    X best0 = buffer[0];
    X best1 = buffer1[0];
    X best2 = buffer2[0];
    X best3 = buffer3[0];
    Z idx0 = 0;
    Z idx1 = 0;
    Z idx2 = 0;
    Z idx3 = 0;
    Z rowBase = 0;  // flat index of the first element of the current row
    for (Z row = 0; row < outerLoopCount; row++) {
        const X* row0 = buffer + rowOffset;
        const X* row1 = buffer1 + rowOffset;
        const X* row2 = buffer2 + rowOffset;
        const X* row3 = buffer3 + rowOffset;
        for (Z col = 0; col < innerLoopCount; col++) {
            ReductionOp::update(best0, idx0, row0[col], rowBase + col);
            ReductionOp::update(best1, idx1, row1[col], rowBase + col);
            ReductionOp::update(best2, idx2, row2[col], rowBase + col);
            ReductionOp::update(best3, idx3, row3[col], rowBase + col);
        }
        rowOffset = sd::inc_coords<outerRank, 0, LastIndexFaster>(coordState, rowOffset);
        rowBase += innerLoopCount;
    }
    *output = idx0;
    *output1 = idx1;
    *output2 = idx2;
    *output3 = idx3;
}
/**
 * Strided twin of the four-region compile-time-rank reduction. The shared outer
 * walk is identical, but inside a row the elements are `inner_stride` apart in
 * memory. Winning flat indices (outerStep * innerLoopCount + innerPos) are
 * stored through output / output1 / output2 / output3.
 */
template<typename X, typename Z, typename ReductionOp, size_t constRank, bool LastIndexFaster = true>
FORCEINLINE void indexInnerReductionConstRankBlock4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3,
    Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong* bases, const Nd4jLong* strides,
    const Nd4jLong& outerLoopCount, const Nd4jLong& innerLoopCount, const Nd4jLong& inner_stride)
{
    LOG_CALLS(0)
    // drop the innermost dimension: it sits at the end for LastIndexFaster, at the front otherwise
    constexpr size_t skipFront = LastIndexFaster ? 0 : 1;
    constexpr size_t outerRank = constRank - 1;
    sd::CoordsState<outerRank - 1> coordState;
    size_t rowOffset = sd::init_coords<outerRank, 0, LastIndexFaster>(coordState, 0, bases + skipFront, strides + skipFront);
    X best0 = buffer[0];
    X best1 = buffer1[0];
    X best2 = buffer2[0];
    X best3 = buffer3[0];
    Z idx0 = 0;
    Z idx1 = 0;
    Z idx2 = 0;
    Z idx3 = 0;
    Z rowBase = 0;  // flat index of the first element of the current row
    for (Z row = 0; row < outerLoopCount; row++) {
        const X* row0 = buffer + rowOffset;
        const X* row1 = buffer1 + rowOffset;
        const X* row2 = buffer2 + rowOffset;
        const X* row3 = buffer3 + rowOffset;
        Nd4jLong memOffset = 0;
        for (Z col = 0; col < innerLoopCount; col++) {
            ReductionOp::update(best0, idx0, row0[memOffset], rowBase + col);
            ReductionOp::update(best1, idx1, row1[memOffset], rowBase + col);
            ReductionOp::update(best2, idx2, row2[memOffset], rowBase + col);
            ReductionOp::update(best3, idx3, row3[memOffset], rowBase + col);
            memOffset += inner_stride;
        }
        rowOffset = sd::inc_coords<outerRank, 0, LastIndexFaster>(coordState, rowOffset);
        rowBase += innerLoopCount;
    }
    *output = idx0;
    *output1 = idx1;
    *output2 = idx2;
    *output3 = idx3;
}
/**
 * Reduces the whole input to a single index (no outer dimensions remain) and
 * writes it to *outputZ.
 *
 * Work is split with samediff::Threads::parallel_tad:
 *  - second_rank == 1: threads take sub-ranges of the single dimension;
 *  - otherwise: threads take ranges of rows [start, stop), each row holding
 *    `inner_last` elements, as laid out by getLength(..., 1, inner_last).
 * Every thread stores its winning value and index in a per-thread slot. The
 * slots are then merged in thread-id order. When the element count is below
 * threadingThreshold, a single thread is used.
 */
template<typename X, typename Z, typename ReductionOp, bool LastIndexFaster = true>
void argIndexCase1Scalar(const int& second_rank,const Nd4jLong* inner_bases,const Nd4jLong* inner_strides, const X* bufferX, Z* outputZ)
{
    Nd4jLong inner_total;
    Nd4jLong inner_last = 0;
    int maxThreads = sd::Environment::getInstance().maxMasterThreads();
    if (second_rank == 1) {
        inner_total = inner_bases[0];
        if (inner_total < threadingThreshold) {
            maxThreads = 1;
        }
    }
    else {
        // inner_total rows of inner_last elements each (see how func consumes them)
        inner_total = getLength<LastIndexFaster>(inner_bases, second_rank, 1, inner_last);
        if (inner_total * inner_last < threadingThreshold) {
            maxThreads = 1;
        }
    }
    // one slot per potential worker for its partial (value, index) result
    auto maxValues = std::make_unique<X[]>(maxThreads);
    auto maxIndices = std::make_unique<Z[]>(maxThreads);
    X* ptrMaxValues = maxValues.get();
    Z* ptrMaxIndices = maxIndices.get();
    auto func = [ptrMaxValues, ptrMaxIndices, inner_last, second_rank, inner_bases, inner_strides, bufferX](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
        const Nd4jLong inner_stride = LastIndexFaster ? inner_strides[second_rank - 1] : inner_strides[0];
        Z argCurrent; X current;
        if (second_rank == 1) {
            const Nd4jLong loopTotal = stop - start;
            if (inner_stride == 1) {
                indexInnerReductionRank1Block4WithMerge<X, Z, ReductionOp>(&(bufferX[start]), current, argCurrent, loopTotal);
            }
            else {
                indexInnerReductionRank1Block4WithMerge<X, Z, ReductionOp>(&(bufferX[start * inner_stride]), current, argCurrent, loopTotal, inner_stride);
            }
            // the rank-1 kernel reports positions relative to `start`
            ptrMaxIndices[thread_id] = argCurrent + start;
        }
        else {
            // indexInnerReduction already reports positions from the start of the region
            if (inner_stride == 1) {
                // unit stride: use the contiguous overload (previously both branches used the strided one)
                indexInnerReduction<X, Z, ReductionOp, LastIndexFaster>(second_rank, bufferX, current, argCurrent, inner_bases, inner_strides, start, stop, inner_last);
            }
            else {
                indexInnerReduction<X, Z, ReductionOp, LastIndexFaster>(second_rank, bufferX, current, argCurrent, inner_bases, inner_strides, start, stop, inner_last, inner_stride);
            }
            ptrMaxIndices[thread_id] = argCurrent;
        }
        ptrMaxValues[thread_id] = current;
    };
#if 0
    int Count = 0;
    func(0, 0, inner_total, 1);
#else
    int Count = samediff::Threads::parallel_tad(func, 0, inner_total, 1, maxThreads);
#endif
    // Merge the per-thread winners. `arg` ends up as the winning thread id.
    // NOTE(review): tie order relies on parallel_tad giving ascending ranges to ascending thread ids.
    Z arg = 0;
    X current = ptrMaxValues[0];
    for (Z i = 1; i < Count; i++) {
        ReductionOp::update(current, arg, ptrMaxValues[i], i);
    }
    *outputZ = ptrMaxIndices[arg];
}
/**
 * Runs the inner index reduction for `loopTotal` consecutive outer positions.
 *
 * For each outer position, `movement.First()` gives the input offset and
 * `movement.Second()` the output offset; `movement.increment()` steps to the
 * next position. The inner region (second_rank dimensions described by
 * inner_bases / inner_strides) is reduced starting at bufferX[First()], and the
 * resulting index is written to outputZ[Second()].
 *
 * Dispatch order:
 *  1. unit vs. non-unit stride of the last inner dimension;
 *  2. inner rank 1, 2, 3, or any other rank.
 * For inner ranks 1..3, positions are consumed four at a time by the Block4
 * kernels, and the loopTotal % 4 leftovers are handled one at a time. For
 * rank 1 the leftovers use the chunked Block4WithMerge kernel once the inner
 * length reaches 2048. Other ranks process every position individually.
 * The LastIndexFaster template parameter is not consulted here.
 */
template<typename X, typename Z, typename ReductionOp, typename Movement, bool LastIndexFaster = true>
void argReductionInnerCases(Movement& movement, Nd4jLong loopTotal, const int& second_rank,const Nd4jLong* inner_bases,const Nd4jLong* inner_strides, const X* bufferX, Z* outputZ)
{
// NOTE(review): always the stride of the LAST inner dimension, whatever the input order; confirm this matches rePartition's layout for 'f' inputs.
Nd4jLong inner_stride = true /*LastIndexFaster*/ ? inner_strides[second_rank - 1] : inner_strides[0];
// positions handled in groups of four, plus the leftover count
Nd4jLong loopTotal_K = loopTotal / 4;
Nd4jLong loopTotal_Tail = loopTotal & 3;
if (inner_stride == 1) {
if (second_rank == 1) {
LOG_CALLS(0)
Nd4jLong inner_total = getLength<true>(inner_bases, second_rank);
for (Nd4jLong i = 0; i < loopTotal_K; i++) {
// collect four consecutive outer positions; increment() order fixes which is which
const X* buffer0 = &(bufferX[movement.First()]);
Z* output0 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer1 = &(bufferX[movement.First()]);
Z* output1 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer2 = &(bufferX[movement.First()]);
Z* output2 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer3 = &(bufferX[movement.First()]);
Z* output3 = &(outputZ[movement.Second()]);
movement.increment();
indexInnerReductionRank1Block4<X, Z, ReductionOp>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_total);
}
// leftovers: long runs use the four-chunk kernel, short runs a plain scan
if (inner_total >= 2048) {
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
X current;
const X* buffer0 = &(bufferX[movement.First()]);
indexInnerReductionRank1Block4WithMerge<X, Z, ReductionOp>(buffer0, current, outputZ[movement.Second()], inner_total);
movement.increment();
}
}
else {
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
X current;
const X* buffer0 = &(bufferX[movement.First()]);
indexInnerReductionRank1<X, Z, ReductionOp>(buffer0, current, outputZ[movement.Second()], inner_total);
movement.increment();
}
}
}
else {
// inner_loop rows of inner_last elements each
Nd4jLong inner_last;
Nd4jLong inner_loop = getLength<true>(inner_bases, second_rank, 1, inner_last);
if (second_rank == 2) {
LOG_CALLS(1)
for (Nd4jLong i = 0; i < loopTotal_K; i++) {
const X* buffer0 = &(bufferX[movement.First()]);
Z* output0 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer1 = &(bufferX[movement.First()]);
Z* output1 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer2 = &(bufferX[movement.First()]);
Z* output2 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer3 = &(bufferX[movement.First()]);
Z* output3 = &(outputZ[movement.Second()]);
movement.increment();
indexInnerReductionConstRankBlock4<X, Z, ReductionOp, 2>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides,
inner_loop, inner_last);
}
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
X current;
const X* buffer0 = &(bufferX[movement.First()]);
indexInnerReductionConstRank<X, Z, ReductionOp, 2>(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, inner_loop, inner_last);
movement.increment();
}
}
else if (second_rank == 3) {
LOG_CALLS(2)
for (Nd4jLong i = 0; i < loopTotal_K; i++) {
const X* buffer0 = &(bufferX[movement.First()]);
Z* output0 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer1 = &(bufferX[movement.First()]);
Z* output1 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer2 = &(bufferX[movement.First()]);
Z* output2 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer3 = &(bufferX[movement.First()]);
Z* output3 = &(outputZ[movement.Second()]);
movement.increment();
indexInnerReductionConstRankBlock4<X, Z, ReductionOp, 3>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides,
inner_loop, inner_last);
}
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
X current;
const X* buffer0 = &(bufferX[movement.First()]);
indexInnerReductionConstRank<X, Z, ReductionOp, 3>(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides,
inner_loop, inner_last);
movement.increment();
}
}
else {
// generic rank: no four-way batching, every position handled individually
LOG_CALLS(3)
//nd4j_printf("-----%d \n", loopTotal);
for (Nd4jLong i = 0; i < loopTotal; i++) {
X current;
const X* buffer0 = &(bufferX[movement.First()]);
indexInnerReduction<X, Z, ReductionOp>(second_rank, buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, 0,
inner_loop, inner_last);
movement.increment();
}
}
}
}
else {
// non-unit inner stride: same structure, using the strided kernel overloads
if (second_rank == 1) {
LOG_CALLS(10)
Nd4jLong inner_total = getLength<true>(inner_bases, second_rank);
for (Nd4jLong i = 0; i < loopTotal_K; i++) {
const X* buffer0 = &(bufferX[movement.First()]);
Z* output0 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer1 = &(bufferX[movement.First()]);
Z* output1 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer2 = &(bufferX[movement.First()]);
Z* output2 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer3 = &(bufferX[movement.First()]);
Z* output3 = &(outputZ[movement.Second()]);
movement.increment();
indexInnerReductionRank1Block4<X, Z, ReductionOp>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_total, inner_stride);
}
if (inner_total >= 2048) {
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
X current;
const X* buffer0 = &(bufferX[movement.First()]);
indexInnerReductionRank1Block4WithMerge<X, Z, ReductionOp>(buffer0, current, outputZ[movement.Second()], inner_total, inner_stride);
movement.increment();
}
}
else {
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
X current;
const X* buffer0 = &(bufferX[movement.First()]);
indexInnerReductionRank1<X, Z, ReductionOp>(buffer0, current, outputZ[movement.Second()], inner_total, inner_stride);
movement.increment();
}
}
}
else {
Nd4jLong inner_last;
Nd4jLong inner_loop = getLength<true>(inner_bases, second_rank, 1, inner_last);
if (second_rank == 2) {
LOG_CALLS(11)
for (Nd4jLong i = 0; i < loopTotal_K; i++) {
const X* buffer0 = &(bufferX[movement.First()]);
Z* output0 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer1 = &(bufferX[movement.First()]);
Z* output1 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer2 = &(bufferX[movement.First()]);
Z* output2 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer3 = &(bufferX[movement.First()]);
Z* output3 = &(outputZ[movement.Second()]);
movement.increment();
indexInnerReductionConstRankBlock4<X, Z, ReductionOp, 2>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides,
inner_loop, inner_last, inner_stride);
}
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
X current;
const X* buffer0 = &(bufferX[movement.First()]);
indexInnerReductionConstRank<X, Z, ReductionOp, 2>(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides,
inner_loop, inner_last, inner_stride);
movement.increment();
}
}
else if (second_rank == 3) {
LOG_CALLS(12)
for (Nd4jLong i = 0; i < loopTotal_K; i++) {
const X* buffer0 = &(bufferX[movement.First()]);
Z* output0 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer1 = &(bufferX[movement.First()]);
Z* output1 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer2 = &(bufferX[movement.First()]);
Z* output2 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer3 = &(bufferX[movement.First()]);
Z* output3 = &(outputZ[movement.Second()]);
movement.increment();
indexInnerReductionConstRankBlock4<X, Z, ReductionOp, 3>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides,
inner_loop, inner_last, inner_stride);
}
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
X current;
const X* buffer0 = &(bufferX[movement.First()]);
indexInnerReductionConstRank<X, Z, ReductionOp, 3>(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides,
inner_loop, inner_last, inner_stride);
movement.increment();
}
}
else {
LOG_CALLS(13)
//nd4j_printf("-------%d inner loop %d inner_last %d\n", loopTotal, inner_loop,inner_last);
for (Nd4jLong i = 0; i < loopTotal; i++) {
X current;
const X* buffer0 = &(bufferX[movement.First()]);
indexInnerReduction<X, Z, ReductionOp>(second_rank, buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, 0,
inner_loop, inner_last, inner_stride);
movement.increment();
}
}
}
}
}
/**
 * Index reduction when outer dimensions remain (first_rank > 0): one output
 * index per outer position.
 *
 * Outer positions [0, total) are split across threads with
 * samediff::Threads::parallel_tad. Each thread builds a coordinate "movement"
 * object that walks the input and output offsets together, then hands it to
 * argReductionInnerCases:
 *  - first_rank == 1: flat rank-1 walker. The unit-stride walker is chosen only
 *    when both the input outer stride and output_stride are 1, because it
 *    receives no stride information; otherwise the walker with explicit strides
 *    is used.
 *  - squashed outer dims with first_rank <= output_rank: fixed-rank walkers for
 *    ranks 2 and 3 (the StrideN variants receive output_stride), else a generic one.
 *  - anything else: generic walker using the full output_strides.
 *
 * Threading heuristic: a single thread when total * inner length is at most
 * threadingThreshold. When the inner stride exceeds the outer stride and
 * total <= 256, at most max(total / 4, 1) threads are used.
 */
template<typename X, typename Z, typename ReductionOp, bool LastIndexFaster = true>
void argIndexCaseNonScalar(const int& first_rank, const int& output_rank, bool squashed, const int& second_rank,
    const Nd4jLong*& outer_bases,const Nd4jLong* outer_strides,const Nd4jLong* output_strides, const Nd4jLong &output_stride,
    const Nd4jLong*& inner_bases,const Nd4jLong* inner_strides, const X* bufferX, Z* outputZ)
{
    Nd4jLong total = getLength<LastIndexFaster>(outer_bases, first_rank);
    // stride of the last inner dimension (always the last, independent of LastIndexFaster)
    Nd4jLong inner_stride = true /*LastIndexFaster*/ ? inner_strides[second_rank - 1] : inner_strides[0];
    // Fastest-varying OUTER stride. outer_strides holds first_rank entries, so it
    // must be indexed with first_rank (the old code used second_rank).
    Nd4jLong outer_stride = LastIndexFaster ? outer_strides[first_rank - 1] : outer_strides[0];
    auto func = [first_rank, output_rank, squashed, outer_bases, outer_strides, output_strides, output_stride, second_rank, inner_bases, inner_strides, bufferX, outputZ](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
        Nd4jLong loopTotal = stop - start;
        Nd4jLong stride = LastIndexFaster ? outer_strides[first_rank - 1] : outer_strides[0];
        if (first_rank == 1) {
            // ZipGenericCoordsRank1Stride1 is initialised without any stride, so it
            // is only valid when input and output both advance by one element.
            if (stride == 1 && output_stride == 1) {
                ZipGenericCoordsRank1Stride1 movement;
                movement.init(nullptr, nullptr, nullptr, 0, start);
                argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
            }
            else {
                ZipGenericCoordsRank1BothStrideN movement;
                movement.init(nullptr, &stride, &output_stride, 0, start);
                argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
            }
        }
        else if (squashed && first_rank <= output_rank) {
            if (first_rank == 2) {
                if (output_stride == 1) {
                    ZipGenericCoordsConstMovementSecondStride1<2, LastIndexFaster> movement;
                    movement.init(outer_bases, outer_strides, nullptr, first_rank, start);
                    argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
                }
                else {
                    ZipGenericCoordsConstMovementSecondStrideN<2, LastIndexFaster> movement;
                    movement.init(outer_bases, outer_strides, &output_stride, first_rank, start);
                    argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
                }
            }
            else if (first_rank == 3) {
                if (output_stride == 1) {
                    ZipGenericCoordsConstMovementSecondStride1<3, LastIndexFaster> movement;
                    movement.init(outer_bases, outer_strides, nullptr, first_rank, start);
                    argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
                }
                else {
                    ZipGenericCoordsConstMovementSecondStrideN<3, LastIndexFaster> movement;
                    movement.init(outer_bases, outer_strides, &output_stride, first_rank, start);
                    argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
                }
            }
            else {
                ZipGenericCoordsMovementSecondStrideN< LastIndexFaster> movement;
                movement.init(outer_bases, outer_strides, &output_stride, first_rank, start);
                argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
            }
        }
        else {
            ZipGenericCoordsMovement<LastIndexFaster> movement;
            movement.init(outer_bases, outer_strides, output_strides, first_rank, start);
            argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
        }
    };
#if 0
    func(0, 0, total, 1);
#else
    uint32_t numThreads = sd::Environment::getInstance().maxMasterThreads();
    Nd4jLong inner_total = getLength<true>(inner_bases, second_rank);
    if (total * inner_total <= threadingThreshold) {
        numThreads = 1;
    }
    else {
        // few outer positions and a widely spaced inner dimension: cap the thread count
        if (inner_stride > outer_stride && total <= 256) {
            auto desired = total > 4 ? (total / 4) : 1;
            numThreads = numThreads > desired ? desired : numThreads;
        }
    }
    samediff::Threads::parallel_tad(func, 0, total, 1, numThreads);
#endif
}
/**
 * Common driver behind argMax_/argMin_/argAbsMax_/argAbsMin_.
 *
 * rePartition splits the input dimensions into an outer part (positions kept
 * in `output`) and an inner part (the dimensions being reduced, selected by
 * `dimensions`). With no outer part, argIndexCase1Scalar produces a single
 * index. Otherwise argIndexCaseNonScalar produces one index per outer position.
 *
 * 'c'-ordered inputs use the LastIndexFaster = true helpers. Other orders use
 * LastIndexFaster = false, except the scalar case with more than one inner
 * dimension, which uses true.
 */
template<typename X, typename Z, typename ReductionOp>
void argIndex_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
    const char input_order = input.ordering();
    const bool cOrder = input_order == 'c';
    // may be refined by rePartition, then forwarded as `squashed`
    bool try_squash_outer = (input_order == output.ordering()) && output.ews() != 0;

    const Nd4jLong* inShape = input.shapeInfo();
    const Nd4jLong* outShape = output.shapeInfo();
    const Nd4jLong rank = inShape[0];
    const Nd4jLong* input_bases = &(inShape[1]);
    const Nd4jLong* input_strides = &(inShape[rank + 1]);
    const Nd4jLong output_rank = outShape[0];
    const Nd4jLong* output_strides = &(outShape[output_rank + 1]);

    Nd4jLong new_bases[MAX_RANK];
    Nd4jLong new_strides[MAX_RANK];
    int first_begin, first_end, second_begin, second_end;
    rePartition(input_order, dimensions, rank, input_bases, input_strides, new_bases, new_strides, first_begin, first_end, second_begin, second_end, try_squash_outer, cOrder);

    const int first_rank = first_end - first_begin;    // 0 means every dimension is reduced
    const int second_rank = second_end - second_begin;
    auto bufferX = input.bufferAsT<X>();
    auto outputZ = output.bufferAsT<Z>();
    // pointers stay non-const themselves: the helpers take them as `const Nd4jLong*&`
    const Nd4jLong* outer_bases = &(new_bases[first_begin]);
    const Nd4jLong* outer_strides = &(new_strides[first_begin]);
    const Nd4jLong* inner_bases = &(new_bases[second_begin]);
    const Nd4jLong* inner_strides = &(new_strides[second_begin]);
    const Nd4jLong output_stride = output.ordering() == 'c' ? output_strides[output_rank-1]:output_strides[0];

    if (first_rank == 0) {
        if (cOrder) {
            argIndexCase1Scalar<X, Z, ReductionOp>(second_rank, inner_bases, inner_strides, bufferX, outputZ);
            return;
        }
        LOG_CALLS(0);
        if (second_rank == 1) {
            argIndexCase1Scalar<X, Z, ReductionOp, false>(second_rank, inner_bases, inner_strides, bufferX, outputZ);
        }
        else {
            argIndexCase1Scalar<X, Z, ReductionOp, true>(second_rank, inner_bases, inner_strides, bufferX, outputZ);
        }
        return;
    }

    if (cOrder) {
        argIndexCaseNonScalar<X, Z, ReductionOp>(first_rank, output_rank, try_squash_outer, second_rank, outer_bases, outer_strides, output_strides,
            output_stride, inner_bases, inner_strides, bufferX, outputZ);
    }
    else {
        LOG_CALLS(1);
        argIndexCaseNonScalar<X, Z, ReductionOp, false>(first_rank, output_rank, try_squash_outer, second_rank, outer_bases, outer_strides, output_strides,
            output_stride, inner_bases, inner_strides, bufferX, outputZ);
    }
}
/**
 * ReductionOp for argMax: keeps the largest value seen so far.
 * Only a strictly greater candidate replaces the held value, so ties keep the
 * earlier index.
 */
template <typename X, typename Z>
struct IndexMax {
    static FORCEINLINE void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) {
        if (!(candidate > current)) {
            return;
        }
        current = candidate;
        currentIndex = candidateIndex;
    }
};
/**
 * ReductionOp for argMin: keeps the smallest value seen so far.
 * Only a strictly smaller candidate replaces the held value, so ties keep the
 * earlier index.
 */
template <typename X, typename Z>
struct IndexMin {
    static FORCEINLINE void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) {
        if (!(candidate < current)) {
            return;
        }
        current = candidate;
        currentIndex = candidateIndex;
    }
};
/**
 * ReductionOp for argAbsMax: keeps the largest absolute value seen so far.
 *
 * After a replacement, `current` holds an absolute value. However, the kernels
 * in this file seed `current` with a raw element (e.g. buffer[0]), and the
 * per-thread/per-chunk merges may feed such raw seeds back in. So `current` is
 * normalised with nd4j_abs before comparing, which makes the result independent
 * of how the state was seeded. Ties keep the earlier index (strict >).
 */
template <typename X, typename Z>
struct IndexAbsMax {
    static FORCEINLINE void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) {
        auto absCandidate = sd::math::nd4j_abs<X>(candidate);
        if (absCandidate > sd::math::nd4j_abs<X>(current)) {
            current = absCandidate;
            currentIndex = candidateIndex;
        }
    }
};
/**
 * ReductionOp for argAbsMin: keeps the smallest absolute value seen so far.
 *
 * The kernels in this file seed `current` with a raw element (e.g. buffer[0]).
 * If that seed is negative, comparing abs(candidate) < current can never
 * succeed, so the seed's index would always win (e.g. {-5, 1} gave 0 instead
 * of 1). Comparing against nd4j_abs(current) fixes this for every seeding and
 * merge path. Ties keep the earlier index (strict <).
 */
template <typename X, typename Z>
struct IndexAbsMin {
    static FORCEINLINE void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) {
        auto absCandidate = sd::math::nd4j_abs<X>(candidate);
        if (absCandidate < sd::math::nd4j_abs<X>(current)) {
            current = absCandidate;
            currentIndex = candidateIndex;
        }
    }
};
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Z>
void argMax_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
return argIndex_<X, Z, IndexMax<X, Z>>(input, output, dimensions);
}
template<typename X, typename Z>
void argMin_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
return argIndex_<X, Z, IndexMin<X, Z>>(input, output, dimensions);
}
template<typename X, typename Z>
void argAbsMax_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
return argIndex_<X, Z, IndexAbsMax<X, Z>>(input, output, dimensions);
}
template<typename X, typename Z>
void argAbsMin_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
return argIndex_<X, Z, IndexAbsMin<X, Z>>(input, output, dimensions);
}
}
}
}