/* ****************************************************************************** * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * See the NOTICE file distributed with this work for additional * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author AbdelRauf // #include #include #include #include #include #include #include #include #if 1 #define LOG_CALLS(X) #else #define LOG_CALLS(X) nd4j_printf("___%s_________%d+\n", __PRETTY_FUNCTION__, X); #endif namespace sd { namespace ops { namespace helpers { constexpr int threadingThreshold = 4096; template FORCEINLINE void indexInnerReductionRank1(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount) { argCurrent = 0; current = buffer[0]; LOG_CALLS(0) Nd4jLong j_offset = 0; for (Z j = 0; j < loopCount; j++) { ReductionOp::update(current, argCurrent, buffer[j], j); } } template FORCEINLINE void indexInnerReductionRank1(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount, const Nd4jLong& inner_stride) { argCurrent = 0; current = buffer[0]; LOG_CALLS(0) Nd4jLong j_offset = 0; for (Z j = 0; j < loopCount; j++) { ReductionOp::update(current, argCurrent, buffer[j_offset], j); j_offset += inner_stride; } } template FORCEINLINE void indexInnerReductionConstRank(const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong outerLoopCount, const Nd4jLong& innerLoopCount) { //skip 1 from the beginning or end depending the Order constexpr size_t updated_index = LastIndexFaster ? 0 : 1; constexpr size_t updated_rank = constRank - 1; sd::CoordsState cst; //we skip 1 size_t offset = sd::init_coords(cst, 0, bases + updated_index, strides + updated_index); Z startIndex = 0; argCurrent = 0; current = buffer[offset]; LOG_CALLS(0) for (Z i = 0; i < outerLoopCount; i++) { const X* inner_buffer = &(buffer[offset]); //typename std::make_signed::type iArgMax = -1; for (Z j = 0; j < innerLoopCount; j++) { ReductionOp::update(current, argCurrent, inner_buffer[j], j + startIndex); } //we skip 1 offset = sd::inc_coords(cst, offset); startIndex += innerLoopCount; } } template FORCEINLINE void indexInnerReductionConstRank(const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong outerLoopCount, const Nd4jLong& innerLoopCount, const Nd4jLong& inner_stride) { //skip 1 from the beginning or end depending the Order constexpr size_t updated_index = LastIndexFaster ? 0 : 1; constexpr size_t updated_rank = constRank - 1; sd::CoordsState cst; //we skip 1 size_t offset = sd::init_coords(cst, 0, bases + updated_index, strides + updated_index); Z startIndex = 0; argCurrent = 0; current = buffer[offset]; LOG_CALLS(0) for (Z i = 0; i < outerLoopCount; i++) { const X* inner_buffer = &(buffer[offset]); for (Z j = 0; j < innerLoopCount; j++) { ReductionOp::update(current, argCurrent, *inner_buffer, j + startIndex); inner_buffer += inner_stride; } //we alreaddy skiped offset = sd::inc_coords(cst, offset); startIndex += innerLoopCount; } } template FORCEINLINE void indexInnerReduction(const int& rank, const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong& outerLoopStart, const Nd4jLong& outerLoopStop, const Nd4jLong& innerLoopCount) { size_t offset = 0; Nd4jLong outerLoopCount = outerLoopStop - outerLoopStart; Nd4jLong coords[MAX_RANK] = {}; Nd4jLong* ptr_coords = (Nd4jLong*)&coords; if (outerLoopStart > 0) { sd::index2coords_C(outerLoopStart, rank - 1, bases, ptr_coords); offset = sd::offset_from_coords(strides, ptr_coords, rank); } Z startIndex = outerLoopStart * innerLoopCount; argCurrent = startIndex; current = buffer[offset]; LOG_CALLS(0) for (Z i = 0; i < outerLoopCount; i++) { const X* inner_buffer = &(buffer[offset]); //typename std::make_signed::type iArgMax = -1; for (Z j = 0; j < innerLoopCount; j++) { //nd4j_printf("%f\n", inner_buffer[j]); ReductionOp::update(current, argCurrent, inner_buffer[j], j + startIndex); } offset = inc_coords(bases, strides, ptr_coords, offset, rank, 1); //if (iArgMax >= 0) argCurrent = startIndex + iArgMax; startIndex += innerLoopCount; } } template FORCEINLINE void indexInnerReduction(const int& rank, const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong& outerLoopStart, const Nd4jLong& outerLoopStop, const Nd4jLong& innerLoopCount, const Nd4jLong& inner_stride) { size_t offset = 0; Nd4jLong outerLoopCount = outerLoopStop - outerLoopStart; Nd4jLong coords[MAX_RANK] = {}; Nd4jLong* ptr_coords = (Nd4jLong*)&coords; if (outerLoopStart > 0) { sd::index2coords_C(outerLoopStart, rank - 1, bases, ptr_coords); offset = sd::offset_from_coords(strides, ptr_coords, rank); } Z startIndex = outerLoopStart * innerLoopCount; argCurrent = startIndex; current = buffer[offset]; LOG_CALLS(0) for (Z i = 0; i < outerLoopCount; i++) { const X* inner_buffer = &(buffer[offset]); //typename std::make_signed::type iArgMax = -1; for (Z j = 0; j < innerLoopCount; j++) { ReductionOp::update(current, argCurrent, inner_buffer[j * inner_stride], startIndex + j); } offset = inc_coords(bases, strides, ptr_coords, offset, rank, 1); //offset = inc_coords(bases, strides, ptr_coords, offset, rank, 1); //if (iArgMax >= 0) argCurrent = startIndex + iArgMax; startIndex += innerLoopCount; } } template FORCEINLINE void indexInnerReductionRank1Block4WithMerge(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount) { argCurrent = 0; current = buffer[0]; LOG_CALLS(0) Nd4jLong loopCount4 = loopCount / 4; Nd4jLong loopCountEnd = loopCount4 + (loopCount & 3); const X* buffer1 = buffer + 1 * loopCount4; const X* buffer2 = buffer1 + 1 * loopCount4; const X* buffer3 = buffer2 + 1 * loopCount4; X current1 = *buffer1; X current2 = *buffer2; X current3 = *buffer3; Z argCurrent1 = 0; Z argCurrent2 = 0; Z argCurrent3 = 0; for (Z j = 0; j < loopCount4; j++) { ReductionOp::update(current, argCurrent, buffer[j], j); ReductionOp::update(current1, argCurrent1, buffer1[j], j); ReductionOp::update(current2, argCurrent2, buffer2[j], j); ReductionOp::update(current3, argCurrent3, buffer3[j], j); } //tail for (Z j = loopCount4; j < loopCountEnd; j++) { ReductionOp::update(current3, argCurrent3, buffer3[j], j); } //merge argCurrent1 += loopCount4; argCurrent2 += 2 * loopCount4; argCurrent3 += 3 * loopCount4; ReductionOp::update(current, argCurrent, current1, argCurrent1); ReductionOp::update(current, argCurrent, current2, argCurrent2); ReductionOp::update(current, argCurrent, current3, argCurrent3); } template FORCEINLINE void indexInnerReductionRank1Block4WithMerge(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount, const Nd4jLong& inner_stride) { argCurrent = 0; current = buffer[0]; LOG_CALLS(0) Nd4jLong loopCount4 = loopCount / 4; Nd4jLong loopCountEnd = loopCount4 + (loopCount & 3); const X* buffer1 = buffer + inner_stride * loopCount4; const X* buffer2 = buffer1 + inner_stride * loopCount4; const X* buffer3 = buffer2 + inner_stride * loopCount4; X current1 = *buffer1; X current2 = *buffer2; X current3 = *buffer3; Z argCurrent1 = 0; Z argCurrent2 = 0; Z argCurrent3 = 0; Nd4jLong j_offset = 0; for (Z j = 0; j < loopCount4; j++) { ReductionOp::update(current, argCurrent, buffer[j_offset], j); ReductionOp::update(current1, argCurrent1, buffer1[j_offset], j); ReductionOp::update(current2, argCurrent2, buffer2[j_offset], j); ReductionOp::update(current3, argCurrent3, buffer3[j_offset], j); j_offset += inner_stride; } //tail for (Z j = loopCount4; j < loopCountEnd; j++) { ReductionOp::update(current3, argCurrent3, buffer3[j_offset], j); j_offset += inner_stride; } //merge argCurrent1 += loopCount4; argCurrent2 += 2 * loopCount4; argCurrent3 += 3 * loopCount4; ReductionOp::update(current, argCurrent, current1, argCurrent1); ReductionOp::update(current, argCurrent, current2, argCurrent2); ReductionOp::update(current, argCurrent, current3, argCurrent3); } template FORCEINLINE void indexInnerReductionRank1Block4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3, Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong& loopCount) { LOG_CALLS(0) Z argCurrent = 0; Z argCurrent1 = 0; Z argCurrent2 = 0; Z argCurrent3 = 0; X current = buffer[0]; X current1 = buffer1[0]; X current2 = buffer2[0]; X current3 = buffer3[0]; for (Z j = 0; j < loopCount; j++) { ReductionOp::update(current, argCurrent, buffer[j], j); ReductionOp::update(current1, argCurrent1, buffer1[j], j); ReductionOp::update(current2, argCurrent2, buffer2[j], j); ReductionOp::update(current3, argCurrent3, buffer3[j], j); } *output = argCurrent; *output1 = argCurrent1; *output2 = argCurrent2; *output3 = argCurrent3; return; } template FORCEINLINE void indexInnerReductionRank1Block4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3, Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong& loopCount, const Nd4jLong& inner_stride) { LOG_CALLS(0) Z argCurrent = 0; Z argCurrent1 = 0; Z argCurrent2 = 0; Z argCurrent3 = 0; X current = buffer[0]; X current1 = buffer1[0]; X current2 = buffer2[0]; X current3 = buffer3[0]; Nd4jLong j_offset = 0; for (Z j = 0; j < loopCount; j++) { ReductionOp::update(current, argCurrent, buffer[j_offset], j); ReductionOp::update(current1, argCurrent1, buffer1[j_offset], j); ReductionOp::update(current2, argCurrent2, buffer2[j_offset], j); ReductionOp::update(current3, argCurrent3, buffer3[j_offset], j); j_offset += inner_stride; } *output = argCurrent; *output1 = argCurrent1; *output2 = argCurrent2; *output3 = argCurrent3; return; } template FORCEINLINE void indexInnerReductionConstRankBlock4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3, Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong& outerLoopCount, const Nd4jLong& innerLoopCount) { LOG_CALLS(0) //skip 1 from the beginning or end depending the Order constexpr size_t updated_index = LastIndexFaster ? 0 : 1; constexpr size_t updated_rank = constRank - 1; sd::CoordsState cst; //we skip 1 size_t offset = sd::init_coords(cst, 0, bases + updated_index, strides + updated_index); Z startIndex = 0; Z argCurrent = 0; Z argCurrent1 = 0; Z argCurrent2 = 0; Z argCurrent3 = 0; X current = buffer[0]; X current1 = buffer1[0]; X current2 = buffer2[0]; X current3 = buffer3[0]; //LOG_CALLS(0) for (Z i = 0; i < outerLoopCount; i++) { const X* inner_buffer = &(buffer[offset]); const X* inner_buffer1 = &(buffer1[offset]); const X* inner_buffer2 = &(buffer2[offset]); const X* inner_buffer3 = &(buffer3[offset]); //typename std::make_signed::type iArgMax = -1; for (Z j = 0; j < innerLoopCount; j++) { ReductionOp::update(current, argCurrent, inner_buffer[j], j + startIndex); ReductionOp::update(current1, argCurrent1, inner_buffer1[j], j + startIndex); ReductionOp::update(current2, argCurrent2, inner_buffer2[j], j + startIndex); ReductionOp::update(current3, argCurrent3, inner_buffer3[j], j + startIndex); } //we skip 1 offset = sd::inc_coords(cst, offset); startIndex += innerLoopCount; } *output = argCurrent; *output1 = argCurrent1; *output2 = argCurrent2; *output3 = argCurrent3; return; } template FORCEINLINE void indexInnerReductionConstRankBlock4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3, Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong& outerLoopCount, const Nd4jLong& innerLoopCount, const Nd4jLong& inner_stride) { LOG_CALLS(0) //skip 1 from the beginning or end depending the Order constexpr size_t updated_index = LastIndexFaster ? 0 : 1; constexpr size_t updated_rank = constRank - 1; sd::CoordsState cst; //we skip 1 size_t offset = sd::init_coords(cst, 0, bases + updated_index, strides + updated_index); Z startIndex = 0; Z argCurrent = 0; Z argCurrent1 = 0; Z argCurrent2 = 0; Z argCurrent3 = 0; X current = buffer[0]; X current1 = buffer1[0]; X current2 = buffer2[0]; X current3 = buffer3[0]; //LOG_CALLS(0) for (Z i = 0; i < outerLoopCount; i++) { const X* inner_buffer = &(buffer[offset]); const X* inner_buffer1 = &(buffer1[offset]); const X* inner_buffer2 = &(buffer2[offset]); const X* inner_buffer3 = &(buffer3[offset]); //typename std::make_signed::type iArgMax = -1; Nd4jLong inner_offset = 0; for (Z j = 0; j < innerLoopCount; j++) { ReductionOp::update(current, argCurrent, inner_buffer[inner_offset], j + startIndex); ReductionOp::update(current1, argCurrent1, inner_buffer1[inner_offset], j + startIndex); ReductionOp::update(current2, argCurrent2, inner_buffer2[inner_offset], j + startIndex); ReductionOp::update(current3, argCurrent3, inner_buffer3[inner_offset], j + startIndex); inner_offset += inner_stride; } //we skip 1 offset = sd::inc_coords(cst, offset); startIndex += innerLoopCount; } *output = argCurrent; *output1 = argCurrent1; *output2 = argCurrent2; *output3 = argCurrent3; return; } template void argIndexCase1Scalar(const int& second_rank,const Nd4jLong* inner_bases,const Nd4jLong* inner_strides, const X* bufferX, Z* outputZ) { Nd4jLong inner_total; Nd4jLong inner_last = 0; int maxThreads = sd::Environment::getInstance().maxMasterThreads(); if (second_rank == 1) { inner_total = inner_bases[0]; if (inner_total < threadingThreshold) { maxThreads = 1; } } else { inner_total = getLength(inner_bases, second_rank, 1, inner_last); if (inner_total * inner_last < threadingThreshold) { maxThreads = 1; } } std::unique_ptr maxValues(new X[maxThreads]); std::unique_ptr maxIndices(new Z[maxThreads]); X* ptrMaxValues = maxValues.get(); Z* ptrMaxIndices = maxIndices.get(); auto func = [ptrMaxValues, ptrMaxIndices, inner_last, second_rank, inner_bases, inner_strides, bufferX](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void { //LOG_CALLS(0) const Nd4jLong inner_stride = LastIndexFaster ? inner_strides[second_rank - 1] : inner_strides[0]; Z argCurrent; X current; if (second_rank == 1) { const Nd4jLong loopTotal = stop - start; if (inner_stride == 1) { indexInnerReductionRank1Block4WithMerge(&(bufferX[start]), current, argCurrent, loopTotal); } else { indexInnerReductionRank1Block4WithMerge(&(bufferX[start * inner_stride]), current, argCurrent, loopTotal, inner_stride); } ptrMaxIndices[thread_id] = argCurrent + start; } else { if (inner_stride == 1) { indexInnerReduction(second_rank, bufferX, current, argCurrent, inner_bases, inner_strides, start, stop, inner_last, inner_stride); } else { indexInnerReduction(second_rank, bufferX, current, argCurrent, inner_bases, inner_strides, start, stop, inner_last, inner_stride); } ptrMaxIndices[thread_id] = argCurrent; } ptrMaxValues[thread_id] = current; }; #if 0 int Count = 0; func(0, 0, inner_total, 1); #else int Count = samediff::Threads::parallel_tad(func, 0, inner_total, 1, maxThreads); #endif Z arg = 0; X current = ptrMaxValues[0]; for (Z i = 1; i < Count; i++) { ReductionOp::update(current, arg, ptrMaxValues[i], i); } *outputZ = ptrMaxIndices[arg]; } template void argReductionInnerCases(Movement& movement, Nd4jLong loopTotal, const int& second_rank,const Nd4jLong* inner_bases,const Nd4jLong* inner_strides, const X* bufferX, Z* outputZ) { Nd4jLong inner_stride = true /*LastIndexFaster*/ ? inner_strides[second_rank - 1] : inner_strides[0]; Nd4jLong loopTotal_K = loopTotal / 4; Nd4jLong loopTotal_Tail = loopTotal & 3; if (inner_stride == 1) { if (second_rank == 1) { LOG_CALLS(0) Nd4jLong inner_total = getLength(inner_bases, second_rank); for (Nd4jLong i = 0; i < loopTotal_K; i++) { const X* buffer0 = &(bufferX[movement.First()]); Z* output0 = &(outputZ[movement.Second()]); movement.increment(); const X* buffer1 = &(bufferX[movement.First()]); Z* output1 = &(outputZ[movement.Second()]); movement.increment(); const X* buffer2 = &(bufferX[movement.First()]); Z* output2 = &(outputZ[movement.Second()]); movement.increment(); const X* buffer3 = &(bufferX[movement.First()]); Z* output3 = &(outputZ[movement.Second()]); movement.increment(); indexInnerReductionRank1Block4(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_total); } if (inner_total >= 2048) { for (Nd4jLong i = 0; i < loopTotal_Tail; i++) { X current; const X* buffer0 = &(bufferX[movement.First()]); indexInnerReductionRank1Block4WithMerge(buffer0, current, outputZ[movement.Second()], inner_total); movement.increment(); } } else { for (Nd4jLong i = 0; i < loopTotal_Tail; i++) { X current; const X* buffer0 = &(bufferX[movement.First()]); indexInnerReductionRank1(buffer0, current, outputZ[movement.Second()], inner_total); movement.increment(); } } } else { Nd4jLong inner_last; Nd4jLong inner_loop = getLength(inner_bases, second_rank, 1, inner_last); if (second_rank == 2) { LOG_CALLS(1) for (Nd4jLong i = 0; i < loopTotal_K; i++) { const X* buffer0 = &(bufferX[movement.First()]); Z* output0 = &(outputZ[movement.Second()]); movement.increment(); const X* buffer1 = &(bufferX[movement.First()]); Z* output1 = &(outputZ[movement.Second()]); movement.increment(); const X* buffer2 = &(bufferX[movement.First()]); Z* output2 = &(outputZ[movement.Second()]); movement.increment(); const X* buffer3 = &(bufferX[movement.First()]); Z* output3 = &(outputZ[movement.Second()]); movement.increment(); indexInnerReductionConstRankBlock4(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides, inner_loop, inner_last); } for (Nd4jLong i = 0; i < loopTotal_Tail; i++) { X current; const X* buffer0 = &(bufferX[movement.First()]); indexInnerReductionConstRank(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, inner_loop, inner_last); movement.increment(); } } else if (second_rank == 3) { LOG_CALLS(2) for (Nd4jLong i = 0; i < loopTotal_K; i++) { const X* buffer0 = &(bufferX[movement.First()]); Z* output0 = &(outputZ[movement.Second()]); movement.increment(); const X* buffer1 = &(bufferX[movement.First()]); Z* output1 = &(outputZ[movement.Second()]); movement.increment(); const X* buffer2 = &(bufferX[movement.First()]); Z* output2 = &(outputZ[movement.Second()]); movement.increment(); const X* buffer3 = &(bufferX[movement.First()]); Z* output3 = &(outputZ[movement.Second()]); movement.increment(); indexInnerReductionConstRankBlock4(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides, inner_loop, inner_last); } for (Nd4jLong i = 0; i < loopTotal_Tail; i++) { X current; const X* buffer0 = &(bufferX[movement.First()]); indexInnerReductionConstRank(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, inner_loop, inner_last); movement.increment(); } } else { LOG_CALLS(3) //nd4j_printf("-----%d \n", loopTotal); for (Nd4jLong i = 0; i < loopTotal; i++) { X current; const X* buffer0 = &(bufferX[movement.First()]); indexInnerReduction(second_rank, buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, 0, inner_loop, inner_last); movement.increment(); } } } } else { if (second_rank == 1) { LOG_CALLS(10) Nd4jLong inner_total = getLength(inner_bases, second_rank); for (Nd4jLong i = 0; i < loopTotal_K; i++) { const X* buffer0 = &(bufferX[movement.First()]); Z* output0 = &(outputZ[movement.Second()]); movement.increment(); const X* buffer1 = &(bufferX[movement.First()]); Z* output1 = &(outputZ[movement.Second()]); movement.increment(); const X* buffer2 = &(bufferX[movement.First()]); Z* output2 = &(outputZ[movement.Second()]); movement.increment(); const X* buffer3 = &(bufferX[movement.First()]); Z* output3 = &(outputZ[movement.Second()]); movement.increment(); indexInnerReductionRank1Block4(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_total, inner_stride); } if (inner_total >= 2048) { for (Nd4jLong i = 0; i < loopTotal_Tail; i++) { X current; const X* buffer0 = &(bufferX[movement.First()]); indexInnerReductionRank1Block4WithMerge(buffer0, current, outputZ[movement.Second()], inner_total, inner_stride); movement.increment(); } } else { for (Nd4jLong i = 0; i < loopTotal_Tail; i++) { X current; const X* buffer0 = &(bufferX[movement.First()]); indexInnerReductionRank1(buffer0, current, outputZ[movement.Second()], inner_total, inner_stride); movement.increment(); } } } else { Nd4jLong inner_last; Nd4jLong inner_loop = getLength(inner_bases, second_rank, 1, inner_last); if (second_rank == 2) { LOG_CALLS(11) for (Nd4jLong i = 0; i < loopTotal_K; i++) { const X* buffer0 = &(bufferX[movement.First()]); Z* output0 = &(outputZ[movement.Second()]); movement.increment(); const X* buffer1 = &(bufferX[movement.First()]); Z* output1 = &(outputZ[movement.Second()]); movement.increment(); const X* buffer2 = &(bufferX[movement.First()]); Z* output2 = &(outputZ[movement.Second()]); movement.increment(); const X* buffer3 = &(bufferX[movement.First()]); Z* output3 = &(outputZ[movement.Second()]); movement.increment(); indexInnerReductionConstRankBlock4(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides, inner_loop, inner_last, inner_stride); } for (Nd4jLong i = 0; i < loopTotal_Tail; i++) { X current; const X* buffer0 = &(bufferX[movement.First()]); indexInnerReductionConstRank(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, inner_loop, inner_last, inner_stride); movement.increment(); } } else if (second_rank == 3) { LOG_CALLS(12) for (Nd4jLong i = 0; i < loopTotal_K; i++) { const X* buffer0 = &(bufferX[movement.First()]); Z* output0 = &(outputZ[movement.Second()]); movement.increment(); const X* buffer1 = &(bufferX[movement.First()]); Z* output1 = &(outputZ[movement.Second()]); movement.increment(); const X* buffer2 = &(bufferX[movement.First()]); Z* output2 = &(outputZ[movement.Second()]); movement.increment(); const X* buffer3 = &(bufferX[movement.First()]); Z* output3 = &(outputZ[movement.Second()]); movement.increment(); indexInnerReductionConstRankBlock4(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides, inner_loop, inner_last, inner_stride); } for (Nd4jLong i = 0; i < loopTotal_Tail; i++) { X current; const X* buffer0 = &(bufferX[movement.First()]); indexInnerReductionConstRank(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, inner_loop, inner_last, inner_stride); movement.increment(); } } else { LOG_CALLS(13) //nd4j_printf("-------%d inner loop %d inner_last %d\n", loopTotal, inner_loop,inner_last); for (Nd4jLong i = 0; i < loopTotal; i++) { X current; const X* buffer0 = &(bufferX[movement.First()]); indexInnerReduction(second_rank, buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, 0, inner_loop, inner_last, inner_stride); movement.increment(); } } } } } template void argIndexCaseNonScalar(const int& first_rank, const int& output_rank, bool squashed, const int& second_rank, const Nd4jLong*& outer_bases,const Nd4jLong* outer_strides,const Nd4jLong* output_strides, const Nd4jLong &output_stride, const Nd4jLong*& inner_bases,const Nd4jLong* inner_strides, const X* bufferX, Z* outputZ) { Nd4jLong total = getLength(outer_bases, first_rank); Nd4jLong inner_stride = true /*LastIndexFaster*/ ? inner_strides[second_rank - 1] : inner_strides[0]; Nd4jLong outer_stride = LastIndexFaster ? outer_strides[second_rank - 1] : outer_strides[0]; auto func = [first_rank, output_rank, squashed, outer_bases, outer_strides, output_strides, output_stride, second_rank, inner_bases, inner_strides, bufferX, outputZ](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void { Nd4jLong loopTotal = stop - start; Nd4jLong stride = LastIndexFaster ? outer_strides[first_rank - 1] : outer_strides[0]; if (first_rank == 1) { if (stride == 1) { ZipGenericCoordsRank1Stride1 movement; movement.init(nullptr, nullptr, nullptr, 0, start); argReductionInnerCases(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ); } else { ZipGenericCoordsRank1BothStrideN movement; movement.init(nullptr, &stride, &output_stride, 0, start); argReductionInnerCases(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ); } } else if (squashed && first_rank <= output_rank) { if (first_rank == 2) { if (output_stride == 1) { ZipGenericCoordsConstMovementSecondStride1<2, LastIndexFaster> movement; movement.init(outer_bases, outer_strides, nullptr, first_rank, start); argReductionInnerCases(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ); } else { ZipGenericCoordsConstMovementSecondStrideN<2, LastIndexFaster> movement; movement.init(outer_bases, outer_strides, &output_stride, first_rank, start); argReductionInnerCases(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ); } } else if (first_rank == 3) { if (output_stride == 1) { ZipGenericCoordsConstMovementSecondStride1<3, LastIndexFaster> movement; movement.init(outer_bases, outer_strides, nullptr, first_rank, start); argReductionInnerCases(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ); } else { ZipGenericCoordsConstMovementSecondStrideN<3, LastIndexFaster> movement; movement.init(outer_bases, outer_strides, &output_stride, first_rank, start); argReductionInnerCases(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ); } } else { ZipGenericCoordsMovementSecondStrideN< LastIndexFaster> movement; movement.init(outer_bases, outer_strides, &output_stride, first_rank, start); argReductionInnerCases(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ); } } else { ZipGenericCoordsMovement movement; movement.init(outer_bases, outer_strides, output_strides, first_rank, start); argReductionInnerCases(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ); } }; #if 0 func(0, 0, total, 1); #else // uint32_t numThreads = sd::Environment::getInstance().maxMasterThreads(); Nd4jLong inner_total = getLength(inner_bases, second_rank); if (total * inner_total <= threadingThreshold) { numThreads = 1; } else { if (inner_stride > outer_stride && total <= 256) { auto desired = total > 4 ? (total / 4) : 1; numThreads = numThreads > desired ? desired : numThreads; } } samediff::Threads::parallel_tad(func, 0, total, 1, numThreads); #endif } template void argIndex_(const NDArray& input, NDArray& output, const std::vector& dimensions) { char input_order = input.ordering(); bool try_squash_outer = (input_order == output.ordering()) && output.ews() != 0; const Nd4jLong* input_shapeInfo = input.shapeInfo(); const Nd4jLong* output_shapeInfo = output.shapeInfo(); const Nd4jLong rank = input_shapeInfo[0]; const Nd4jLong* input_bases = &(input_shapeInfo[1]); const Nd4jLong* input_strides = &(input_shapeInfo[rank + 1]); const Nd4jLong output_rank = output_shapeInfo[0]; const Nd4jLong* output_strides = &(output_shapeInfo[output_rank + 1]); Nd4jLong new_bases[MAX_RANK]; Nd4jLong new_strides[MAX_RANK]; int first_begin, first_end, second_begin, second_end; //rePartition into two parts based on the selection rePartition(input_order, dimensions, rank, input_bases, input_strides, new_bases, new_strides, first_begin, first_end, second_begin, second_end, try_squash_outer, input_order == 'c'); int first_rank = first_end - first_begin; //the first rank can be 0 for scalar cases int second_rank = second_end - second_begin; auto bufferX = input.bufferAsT(); auto outputZ = output.bufferAsT(); const Nd4jLong* outer_bases = &(new_bases[first_begin]); const Nd4jLong* outer_strides = &(new_strides[first_begin]); const Nd4jLong* inner_bases = &(new_bases[second_begin]); const Nd4jLong* inner_strides = &(new_strides[second_begin]); const Nd4jLong output_stride = output.ordering() == 'c' ? output_strides[output_rank-1]:output_strides[0]; if (input_order == 'c') { if (first_rank == 0) { argIndexCase1Scalar(second_rank, inner_bases, inner_strides, bufferX, outputZ); } else { argIndexCaseNonScalar(first_rank, output_rank, try_squash_outer, second_rank, outer_bases, outer_strides, output_strides, output_stride,inner_bases, inner_strides, bufferX, outputZ); } } else { if (first_rank == 0) { LOG_CALLS(0); if (second_rank == 1) { argIndexCase1Scalar(second_rank, inner_bases, inner_strides, bufferX, outputZ); } else { argIndexCase1Scalar(second_rank, inner_bases, inner_strides, bufferX, outputZ); } } else { LOG_CALLS(1); argIndexCaseNonScalar(first_rank, output_rank, try_squash_outer, second_rank, outer_bases, outer_strides, output_strides, output_stride, inner_bases, inner_strides, bufferX, outputZ); } } } template struct IndexMax { static FORCEINLINE void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) { if (candidate > current) { current = candidate; currentIndex = candidateIndex; } } }; template struct IndexMin { static FORCEINLINE void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) { if (candidate < current) { current = candidate; currentIndex = candidateIndex; } } }; template struct IndexAbsMax { static FORCEINLINE void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) { auto absCandidate = sd::math::nd4j_abs(candidate); if (absCandidate > current) { current = absCandidate; currentIndex = candidateIndex; } } }; template struct IndexAbsMin { static FORCEINLINE void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) { auto absCandidate = sd::math::nd4j_abs(candidate); if (absCandidate < current) { current = absCandidate; currentIndex = candidateIndex; } } }; ////////////////////////////////////////////////////////////////////////// template void argMax_(const NDArray& input, NDArray& output, const std::vector& dimensions) { return argIndex_>(input, output, dimensions); } template void argMin_(const NDArray& input, NDArray& output, const std::vector& dimensions) { return argIndex_>(input, output, dimensions); } template void argAbsMax_(const NDArray& input, NDArray& output, const std::vector& dimensions) { return argIndex_>(input, output, dimensions); } template void argAbsMin_(const NDArray& input, NDArray& output, const std::vector& dimensions) { return argIndex_>(input, output, dimensions); } } } }