/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
 //
 // @author AbdelRauf 
 //
#include <type_traits>
#include <cmath>
#include <stdexcept>
#include <memory>
#include <execution/Threads.h>
#include <execution/ThreadPool.h>
#include <helpers/LoopsCoordsHelper.h>
#include <ops/declarable/helpers/reductions.h>
// Call-tracing switch for the kernels below.
// With `#if 1` (as shipped) LOG_CALLS(X) expands to nothing, so every LOG_CALLS(...)
// statement compiles away. Flip it to `#if 0` to make each LOG_CALLS(X) print the
// enclosing function signature (__PRETTY_FUNCTION__) and the integer tag X via nd4j_printf.
#if 1
#define  LOG_CALLS(X) 
#else
 
#define  LOG_CALLS(X)  nd4j_printf("___%s_________%d+\n", __PRETTY_FUNCTION__, X); 
#endif
namespace sd {
	namespace ops {
		namespace helpers {
			// Workload size (number of elements to reduce) under which the dispatchers in this
			// file force a single thread instead of splitting the work with samediff::Threads.
			constexpr int threadingThreshold{ 4096 };
			/**
			 * Reduces a contiguous 1-D run of `loopCount` elements with ReductionOp.
			 *
			 * On return `current` holds the selected value and `argCurrent` its index, counted
			 * from `buffer`. The accumulator is seeded from buffer[0] unconditionally, so the
			 * caller must pass loopCount >= 1.
			 */
			template<typename X, typename Z, typename ReductionOp>
			FORCEINLINE void indexInnerReductionRank1(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount)
			{
				argCurrent = 0;
				current = buffer[0];
				LOG_CALLS(0)
				for (Z j = 0; j < loopCount; j++) {
					ReductionOp::update(current, argCurrent, buffer[j], j);
				}
			}

			/**
			 * Strided 1-D reduction: visits buffer[0], buffer[inner_stride], buffer[2 * inner_stride], ...
			 * for `loopCount` elements.
			 *
			 * `argCurrent` receives the logical element number (0 .. loopCount-1), not the memory
			 * offset. buffer[0] seeds the accumulator, so loopCount must be >= 1.
			 */
			template<typename X, typename Z, typename ReductionOp>
			FORCEINLINE void indexInnerReductionRank1(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount, const Nd4jLong& inner_stride)
			{
				argCurrent = 0;
				current = *buffer;
				LOG_CALLS(0)
				Nd4jLong memOffset = 0;
				for (Z idx = 0; idx < loopCount; ++idx, memOffset += inner_stride) {
					ReductionOp::update(current, argCurrent, buffer[memOffset], idx);
				}
			}

			/**
			 * Reduces one sub-array of compile-time rank `constRank` whose reduced inner dimension
			 * is read contiguously (`innerLoopCount` elements per row).
			 *
			 * The remaining constRank - 1 dimensions are walked with sd::CoordsState. They are taken
			 * from the start of `bases`/`strides` when LastIndexFaster is true, and from index 1
			 * otherwise. `argCurrent` receives the flattened logical index
			 * row * innerLoopCount + column.
			 */
			template<typename X, typename Z, typename ReductionOp, size_t constRank, bool LastIndexFaster = true>
			FORCEINLINE void indexInnerReductionConstRank(const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong outerLoopCount, const Nd4jLong& innerLoopCount)
			{
				// rank walked by the coordinate state, and where its dims begin in bases/strides
				constexpr size_t outerRank = constRank - 1;
				constexpr size_t firstOuterDim = LastIndexFaster ? 0 : 1;
				sd::CoordsState<outerRank - 1> coordState;
				size_t offset = sd::init_coords<outerRank, 0, LastIndexFaster>(coordState, 0, bases + firstOuterDim, strides + firstOuterDim);
				argCurrent = 0;
				current = buffer[offset];
				LOG_CALLS(0)
				Z rowBase = 0;
				for (Z row = 0; row < outerLoopCount; row++) {
					const X* rowPtr = buffer + offset;
					for (Z col = 0; col < innerLoopCount; col++) {
						ReductionOp::update(current, argCurrent, rowPtr[col], col + rowBase);
					}
					// advance the outer coordinates to the next row
					offset = sd::inc_coords<outerRank, 0, LastIndexFaster>(coordState, offset);
					rowBase += innerLoopCount;
				}
			}

			/**
			 * Strided counterpart of the contiguous indexInnerReductionConstRank. Within a row,
			 * consecutive elements are `inner_stride` apart.
			 *
			 * The outer constRank - 1 dimensions are walked with sd::CoordsState, using the same
			 * dimension selection as the contiguous version. `argCurrent` receives the flattened
			 * logical index row * innerLoopCount + column.
			 */
			template<typename X, typename Z, typename ReductionOp, size_t constRank, bool LastIndexFaster = true>
			FORCEINLINE void indexInnerReductionConstRank(const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong outerLoopCount, const Nd4jLong& innerLoopCount, const Nd4jLong& inner_stride)
			{
				constexpr size_t outerRank = constRank - 1;
				constexpr size_t firstOuterDim = LastIndexFaster ? 0 : 1;
				sd::CoordsState<outerRank - 1> coordState;
				size_t offset = sd::init_coords<outerRank, 0, LastIndexFaster>(coordState, 0, bases + firstOuterDim, strides + firstOuterDim);
				argCurrent = 0;
				current = buffer[offset];
				LOG_CALLS(0)
				Z rowBase = 0;
				for (Z row = 0; row < outerLoopCount; row++) {
					const X* rowPtr = buffer + offset;
					// integer step instead of bumping a pointer, so nothing is formed past the row's end
					Nd4jLong step = 0;
					for (Z col = 0; col < innerLoopCount; col++) {
						ReductionOp::update(current, argCurrent, rowPtr[step], col + rowBase);
						step += inner_stride;
					}
					offset = sd::inc_coords<outerRank, 0, LastIndexFaster>(coordState, offset);
					rowBase += innerLoopCount;
				}
			}

			/**
			 * Generic-rank reduction over rows [outerLoopStart, outerLoopStop) of a sub-array.
			 * Each row is `innerLoopCount` contiguous elements.
			 *
			 * Row coordinates are derived in C order over the first rank - 1 dimensions; the last
			 * dimension is the one that is skipped when advancing. `argCurrent` receives a flat
			 * index measured from row 0 (row * innerLoopCount + column), so results from disjoint
			 * row ranges are directly comparable.
			 * NOTE: LastIndexFaster is currently not consulted; inc_coords is instantiated with `true`.
			 */
			template<typename X, typename Z, typename ReductionOp, bool LastIndexFaster = true>
			FORCEINLINE void indexInnerReduction(const int& rank, const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong& outerLoopStart, const Nd4jLong& outerLoopStop, const Nd4jLong& innerLoopCount)
			{
				const Nd4jLong rows = outerLoopStop - outerLoopStart;
				Nd4jLong coords[MAX_RANK] = {};
				Nd4jLong* cursor = coords;
				size_t offset = 0;
				if (outerLoopStart > 0) {
					// jump the coordinate cursor directly to the first row of this range
					sd::index2coords_C(outerLoopStart, rank - 1, bases, cursor);
					offset = sd::offset_from_coords(strides, cursor, rank);
				}
				Z rowBase = outerLoopStart * innerLoopCount;
				argCurrent = rowBase;
				current = buffer[offset];
				LOG_CALLS(0)
				for (Z row = 0; row < rows; row++) {
					const X* rowPtr = buffer + offset;
					for (Z col = 0; col < innerLoopCount; col++) {
						ReductionOp::update(current, argCurrent, rowPtr[col], col + rowBase);
					}
					offset = inc_coords<true>(bases, strides, cursor, offset, rank, 1);
					rowBase += innerLoopCount;
				}
			}

			/**
			 * Strided counterpart of the contiguous indexInnerReduction. Within a row, consecutive
			 * elements are `inner_stride` apart in memory.
			 *
			 * It covers rows [outerLoopStart, outerLoopStop), walked in C order over the first
			 * rank - 1 dimensions. `argCurrent` receives the flat index row * innerLoopCount +
			 * column, measured from row 0.
			 * NOTE: LastIndexFaster is currently not consulted; inc_coords is instantiated with `true`.
			 */
			template<typename X, typename Z, typename ReductionOp, bool LastIndexFaster = true>
			FORCEINLINE void indexInnerReduction(const int& rank, const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong& outerLoopStart, const Nd4jLong& outerLoopStop, const Nd4jLong& innerLoopCount, const Nd4jLong& inner_stride)
			{
				const Nd4jLong rows = outerLoopStop - outerLoopStart;
				Nd4jLong coords[MAX_RANK] = {};
				Nd4jLong* cursor = coords;
				size_t offset = 0;
				if (outerLoopStart > 0) {
					sd::index2coords_C(outerLoopStart, rank - 1, bases, cursor);
					offset = sd::offset_from_coords(strides, cursor, rank);
				}
				Z rowBase = outerLoopStart * innerLoopCount;
				argCurrent = rowBase;
				current = buffer[offset];
				LOG_CALLS(0)
				for (Z row = 0; row < rows; row++) {
					const X* rowPtr = buffer + offset;
					for (Z col = 0; col < innerLoopCount; col++) {
						ReductionOp::update(current, argCurrent, rowPtr[col * inner_stride], rowBase + col);
					}
					offset = inc_coords<true>(bases, strides, cursor, offset, rank, 1);
					rowBase += innerLoopCount;
				}
			}

			/**
			 * Reduces one contiguous run of `loopCount` elements by splitting it into four equal
			 * quarters scanned in lock-step with independent accumulators.
			 *
			 * The loopCount % 4 leftover elements are appended to the fourth quarter. Quarter
			 * results are rebased onto the whole run and folded into the first accumulator in
			 * ascending quarter order, so `argCurrent` is an index counted from `buffer`.
			 * buffer[0] is read unconditionally, so loopCount must be >= 1.
			 */
			template<typename X, typename Z, typename ReductionOp>
			FORCEINLINE void indexInnerReductionRank1Block4WithMerge(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount)
			{
				argCurrent = 0;
				current = buffer[0];
				LOG_CALLS(0)
				const Nd4jLong quarter = loopCount / 4;
				const Nd4jLong lastQuarterEnd = quarter + (loopCount & 3);
				const X* q1 = buffer + quarter;
				const X* q2 = q1 + quarter;
				const X* q3 = q2 + quarter;
				X best1 = q1[0];
				X best2 = q2[0];
				X best3 = q3[0];
				Z idx1 = 0;
				Z idx2 = 0;
				Z idx3 = 0;
				for (Z k = 0; k < quarter; k++) {
					ReductionOp::update(current, argCurrent, buffer[k], k);
					ReductionOp::update(best1, idx1, q1[k], k);
					ReductionOp::update(best2, idx2, q2[k], k);
					ReductionOp::update(best3, idx3, q3[k], k);
				}
				// the remainder elements sit right after the fourth quarter's regular range
				for (Z k = quarter; k < lastQuarterEnd; k++) {
					ReductionOp::update(best3, idx3, q3[k], k);
				}
				// convert quarter-local indices into run-relative ones, then merge 1 -> 2 -> 3
				idx1 += quarter;
				idx2 += 2 * quarter;
				idx3 += 3 * quarter;
				ReductionOp::update(current, argCurrent, best1, idx1);
				ReductionOp::update(current, argCurrent, best2, idx2);
				ReductionOp::update(current, argCurrent, best3, idx3);
			}

			/**
			 * Strided form of the quarter-split reduction: the run has `loopCount` logical
			 * elements spaced `inner_stride` apart.
			 *
			 * Quarters are scanned in lock-step, and the loopCount % 4 leftovers go to the fourth
			 * quarter. Results are merged in ascending quarter order. `argCurrent` is the logical
			 * element number within the run (not a memory offset). buffer[0] is read
			 * unconditionally, so loopCount must be >= 1.
			 */
			template<typename X, typename Z, typename ReductionOp>
			FORCEINLINE void indexInnerReductionRank1Block4WithMerge(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount, const Nd4jLong& inner_stride)
			{
				argCurrent = 0;
				current = buffer[0];
				LOG_CALLS(0)
				const Nd4jLong quarter = loopCount / 4;
				const Nd4jLong lastQuarterEnd = quarter + (loopCount & 3);
				// memory distance between the starts of two consecutive quarters
				const Nd4jLong quarterSpan = inner_stride * quarter;
				const X* q1 = buffer + quarterSpan;
				const X* q2 = q1 + quarterSpan;
				const X* q3 = q2 + quarterSpan;
				X best1 = *q1;
				X best2 = *q2;
				X best3 = *q3;
				Z idx1 = 0;
				Z idx2 = 0;
				Z idx3 = 0;
				Nd4jLong pos = 0;
				for (Z k = 0; k < quarter; k++, pos += inner_stride) {
					ReductionOp::update(current, argCurrent, buffer[pos], k);
					ReductionOp::update(best1, idx1, q1[pos], k);
					ReductionOp::update(best2, idx2, q2[pos], k);
					ReductionOp::update(best3, idx3, q3[pos], k);
				}
				// `pos` keeps running so the remainder continues right after the fourth quarter
				for (Z k = quarter; k < lastQuarterEnd; k++, pos += inner_stride) {
					ReductionOp::update(best3, idx3, q3[pos], k);
				}
				idx1 += quarter;
				idx2 += 2 * quarter;
				idx3 += 3 * quarter;
				ReductionOp::update(current, argCurrent, best1, idx1);
				ReductionOp::update(current, argCurrent, best2, idx2);
				ReductionOp::update(current, argCurrent, best3, idx3);
			}

			/**
			 * Reduces four independent contiguous runs of `loopCount` elements
			 * (buffer, buffer1, buffer2, buffer3) in one interleaved pass.
			 *
			 * Each run's winning index, counted from that run's first element, is stored through
			 * the matching output pointer. Element 0 of every run is read unconditionally, so
			 * loopCount must be >= 1.
			 */
			template<typename X, typename Z, typename ReductionOp>
			FORCEINLINE void indexInnerReductionRank1Block4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3, Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong& loopCount)
			{
				LOG_CALLS(0)
				X best0 = buffer[0];
				X best1 = buffer1[0];
				X best2 = buffer2[0];
				X best3 = buffer3[0];
				Z idx0 = 0;
				Z idx1 = 0;
				Z idx2 = 0;
				Z idx3 = 0;
				for (Z k = 0; k < loopCount; ++k) {
					ReductionOp::update(best0, idx0, buffer[k], k);
					ReductionOp::update(best1, idx1, buffer1[k], k);
					ReductionOp::update(best2, idx2, buffer2[k], k);
					ReductionOp::update(best3, idx3, buffer3[k], k);
				}
				*output = idx0;
				*output1 = idx1;
				*output2 = idx2;
				*output3 = idx3;
			}

			/**
			 * Strided form of the four-run reduction. Every run has `loopCount` logical elements
			 * spaced `inner_stride` apart.
			 *
			 * Each run's winning logical index is stored through the matching output pointer.
			 * Element 0 of every run is read unconditionally, so loopCount must be >= 1.
			 */
			template<typename X, typename Z, typename ReductionOp>
			FORCEINLINE void indexInnerReductionRank1Block4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3, Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong& loopCount, const Nd4jLong& inner_stride)
			{
				LOG_CALLS(0)
				X best0 = buffer[0];
				X best1 = buffer1[0];
				X best2 = buffer2[0];
				X best3 = buffer3[0];
				Z idx0 = 0;
				Z idx1 = 0;
				Z idx2 = 0;
				Z idx3 = 0;
				Nd4jLong pos = 0;
				for (Z k = 0; k < loopCount; ++k, pos += inner_stride) {
					ReductionOp::update(best0, idx0, buffer[pos], k);
					ReductionOp::update(best1, idx1, buffer1[pos], k);
					ReductionOp::update(best2, idx2, buffer2[pos], k);
					ReductionOp::update(best3, idx3, buffer3[pos], k);
				}
				*output = idx0;
				*output1 = idx1;
				*output2 = idx2;
				*output3 = idx3;
			}

			/**
			 * Four-way batched form of the contiguous indexInnerReductionConstRank. It reduces four
			 * sub-arrays of identical shape that share `bases`/`strides` and start at buffer..buffer3.
			 *
			 * Each winning flat index (row * innerLoopCount + column) is written through the
			 * matching output pointer. Rows are read contiguously; the outer constRank - 1 dims are
			 * walked with sd::CoordsState, starting at index 0 of bases/strides when
			 * LastIndexFaster is true and at index 1 otherwise.
			 */
			template<typename X, typename Z, typename ReductionOp, size_t constRank, bool LastIndexFaster = true>
			FORCEINLINE void indexInnerReductionConstRankBlock4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3,
				Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong* bases, const Nd4jLong* strides,
				const Nd4jLong& outerLoopCount, const Nd4jLong& innerLoopCount)
			{
				LOG_CALLS(0)
				constexpr size_t outerRank = constRank - 1;
				constexpr size_t firstOuterDim = LastIndexFaster ? 0 : 1;
				sd::CoordsState<outerRank - 1> coordState;
				size_t offset = sd::init_coords<outerRank, 0, LastIndexFaster>(coordState, 0, bases + firstOuterDim, strides + firstOuterDim);
				X best0 = buffer[0];
				X best1 = buffer1[0];
				X best2 = buffer2[0];
				X best3 = buffer3[0];
				Z idx0 = 0;
				Z idx1 = 0;
				Z idx2 = 0;
				Z idx3 = 0;
				Z rowBase = 0;
				for (Z row = 0; row < outerLoopCount; row++) {
					// all four sub-arrays share the same layout, so one offset addresses each of them
					const X* r0 = buffer + offset;
					const X* r1 = buffer1 + offset;
					const X* r2 = buffer2 + offset;
					const X* r3 = buffer3 + offset;
					for (Z col = 0; col < innerLoopCount; col++) {
						ReductionOp::update(best0, idx0, r0[col], col + rowBase);
						ReductionOp::update(best1, idx1, r1[col], col + rowBase);
						ReductionOp::update(best2, idx2, r2[col], col + rowBase);
						ReductionOp::update(best3, idx3, r3[col], col + rowBase);
					}
					offset = sd::inc_coords<outerRank, 0, LastIndexFaster>(coordState, offset);
					rowBase += innerLoopCount;
				}
				*output = idx0;
				*output1 = idx1;
				*output2 = idx2;
				*output3 = idx3;
			}

			/**
			 * Strided form of the four-way batched indexInnerReductionConstRankBlock4. Within a row,
			 * consecutive elements of each of the four sub-arrays are `inner_stride` apart.
			 *
			 * Each winning flat index (row * innerLoopCount + column) is written through the
			 * matching output pointer.
			 */
			template<typename X, typename Z, typename ReductionOp, size_t constRank, bool LastIndexFaster = true>
			FORCEINLINE void indexInnerReductionConstRankBlock4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3,
				Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong* bases, const Nd4jLong* strides,
				const Nd4jLong& outerLoopCount, const Nd4jLong& innerLoopCount, const Nd4jLong& inner_stride)
			{
				LOG_CALLS(0)
				constexpr size_t outerRank = constRank - 1;
				constexpr size_t firstOuterDim = LastIndexFaster ? 0 : 1;
				sd::CoordsState<outerRank - 1> coordState;
				size_t offset = sd::init_coords<outerRank, 0, LastIndexFaster>(coordState, 0, bases + firstOuterDim, strides + firstOuterDim);
				X best0 = buffer[0];
				X best1 = buffer1[0];
				X best2 = buffer2[0];
				X best3 = buffer3[0];
				Z idx0 = 0;
				Z idx1 = 0;
				Z idx2 = 0;
				Z idx3 = 0;
				Z rowBase = 0;
				for (Z row = 0; row < outerLoopCount; row++) {
					const X* r0 = buffer + offset;
					const X* r1 = buffer1 + offset;
					const X* r2 = buffer2 + offset;
					const X* r3 = buffer3 + offset;
					Nd4jLong step = 0;
					for (Z col = 0; col < innerLoopCount; col++, step += inner_stride) {
						ReductionOp::update(best0, idx0, r0[step], col + rowBase);
						ReductionOp::update(best1, idx1, r1[step], col + rowBase);
						ReductionOp::update(best2, idx2, r2[step], col + rowBase);
						ReductionOp::update(best3, idx3, r3[step], col + rowBase);
					}
					offset = sd::inc_coords<outerRank, 0, LastIndexFaster>(coordState, offset);
					rowBase += innerLoopCount;
				}
				*output = idx0;
				*output1 = idx1;
				*output2 = idx2;
				*output3 = idx3;
			}

			/**
			 * Whole-array arg-reduction with a scalar result: reduces every element described by
			 * `inner_bases`/`inner_strides` (rank `second_rank`) and stores the winning flat index
			 * in *outputZ.
			 *
			 * Work is split with samediff::Threads::parallel_tad. Each thread records its local
			 * winner (value plus global flat index) in a per-thread slot, and the slots are then
			 * merged in thread-id order with ReductionOp. Workloads below threadingThreshold
			 * elements use a single thread.
			 *
			 * For second_rank == 1 the threads split the elements directly. Otherwise they split
			 * the outer rows: `inner_total` rows of `inner_last` elements each.
			 */
			template<typename X, typename Z, typename ReductionOp, bool LastIndexFaster = true>
			void argIndexCase1Scalar(const  int& second_rank,const Nd4jLong* inner_bases,const Nd4jLong* inner_strides, const  X* bufferX, Z* outputZ)
			{
				Nd4jLong inner_total;
				Nd4jLong inner_last = 0;
				int maxThreads = sd::Environment::getInstance().maxMasterThreads();
				if (second_rank == 1) {
					inner_total = inner_bases[0];
					if (inner_total < threadingThreshold) {
						maxThreads = 1;
					}
				}
				else {
					// inner_total / inner_last are used below as the outer row range and the row length
					inner_total = getLength<LastIndexFaster>(inner_bases, second_rank, 1, inner_last);
					if (inner_total * inner_last < threadingThreshold) {
						maxThreads = 1;
					}
				}

				// one (value, index) slot per potential thread; merged after the parallel section
				std::unique_ptr<X[]> maxValues(new X[maxThreads]);
				std::unique_ptr<Z[]> maxIndices(new Z[maxThreads]);
				X* ptrMaxValues = maxValues.get();
				Z* ptrMaxIndices = maxIndices.get();
				auto func = [ptrMaxValues, ptrMaxIndices, inner_last, second_rank, inner_bases, inner_strides, bufferX](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
					//LOG_CALLS(0)
					const Nd4jLong inner_stride = LastIndexFaster ? inner_strides[second_rank - 1] : inner_strides[0];
					Z argCurrent; X current;
					if (second_rank == 1) {
						const Nd4jLong loopTotal = stop - start;
						if (inner_stride == 1) {
							indexInnerReductionRank1Block4WithMerge<X, Z, ReductionOp>(&(bufferX[start]), current, argCurrent, loopTotal);
						}
						else {
							indexInnerReductionRank1Block4WithMerge<X, Z, ReductionOp>(&(bufferX[start * inner_stride]), current, argCurrent, loopTotal, inner_stride);
						}
						// the rank-1 kernel reports an index relative to `start`; make it global
						ptrMaxIndices[thread_id] = argCurrent + start;
					}
					else {
						if (inner_stride == 1) {
							// unit-stride rows: use the contiguous overload so the inner loop is a plain linear scan
							indexInnerReduction<X, Z, ReductionOp, LastIndexFaster>(second_rank, bufferX, current, argCurrent, inner_bases, inner_strides, start, stop, inner_last);
						}
						else {
							indexInnerReduction<X, Z, ReductionOp, LastIndexFaster>(second_rank, bufferX, current, argCurrent, inner_bases, inner_strides, start, stop, inner_last, inner_stride);
						}
						// indexInnerReduction already yields an index measured from row 0
						ptrMaxIndices[thread_id] = argCurrent;
					}
					ptrMaxValues[thread_id] = current;
				};
#if 0
				int Count = 0;
				func(0, 0, inner_total, 1);
#else
				int Count = samediff::Threads::parallel_tad(func, 0, inner_total, 1, maxThreads);
#endif
				// merge the per-thread winners in thread-id order; `arg` ends up as the winning slot
				Z arg = 0;
				X current = ptrMaxValues[0];

				for (Z i = 1; i < Count; i++) {
					ReductionOp::update(current, arg, ptrMaxValues[i], i);
				}

				*outputZ = ptrMaxIndices[arg];
			}


			/**
			 * Runs the inner arg-reduction for `loopTotal` consecutive positions of `movement`.
			 *
			 * For each position, movement.First() is the input offset of the sub-array to reduce and
			 * movement.Second() is the output offset that receives its winning index.
			 * movement.increment() advances to the next position, and the First/Second/increment
			 * sequence below must stay in exactly this order.
			 *
			 * The sub-array is described by `inner_bases`/`inner_strides` (rank `second_rank`).
			 * A kernel is chosen by innermost stride (unit vs strided) and by second_rank (1, 2, 3,
			 * or generic). For ranks 1-3, positions are handled four at a time with the Block4
			 * kernels, and the loopTotal % 4 leftovers one at a time. The generic-rank path handles
			 * one position per iteration.
			 *
			 * NOTE: the innermost stride is always taken from the last dimension; the
			 * LastIndexFaster template parameter is not consulted in this function.
			 */
			template<typename X, typename Z, typename ReductionOp, typename Movement, bool LastIndexFaster = true>
			void argReductionInnerCases(Movement& movement, Nd4jLong loopTotal, const int& second_rank,const Nd4jLong* inner_bases,const Nd4jLong* inner_strides, const X* bufferX, Z* outputZ)
			{

				Nd4jLong inner_stride = true /*LastIndexFaster*/ ? inner_strides[second_rank - 1] : inner_strides[0];

				// positions processed in groups of four, plus the remainder
				Nd4jLong loopTotal_K = loopTotal / 4;
				Nd4jLong loopTotal_Tail = loopTotal & 3;
				if (inner_stride == 1) {
					if (second_rank == 1) {
						LOG_CALLS(0)
						Nd4jLong inner_total = getLength<true>(inner_bases, second_rank);
						for (Nd4jLong i = 0; i < loopTotal_K; i++) {
							const X* buffer0 = &(bufferX[movement.First()]);
							Z* output0 = &(outputZ[movement.Second()]);
							movement.increment();
							const X* buffer1 = &(bufferX[movement.First()]);
							Z* output1 = &(outputZ[movement.Second()]);
							movement.increment();
							const X* buffer2 = &(bufferX[movement.First()]);
							Z* output2 = &(outputZ[movement.Second()]);
							movement.increment();
							const X* buffer3 = &(bufferX[movement.First()]);
							Z* output3 = &(outputZ[movement.Second()]);
							movement.increment();
							indexInnerReductionRank1Block4<X, Z, ReductionOp>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_total);

						}
						// leftovers: long rows use the quarter-split kernel, short rows the plain scan
						if (inner_total >= 2048) {
							for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
								// only the index is kept; the winning value in `current` is discarded
								X current;
								const X* buffer0 = &(bufferX[movement.First()]);
								indexInnerReductionRank1Block4WithMerge<X, Z, ReductionOp>(buffer0, current, outputZ[movement.Second()], inner_total);
								movement.increment();
							}
						}
						else {
							for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
								X current;
								const X* buffer0 = &(bufferX[movement.First()]);
								indexInnerReductionRank1<X, Z, ReductionOp>(buffer0, current, outputZ[movement.Second()], inner_total);
								movement.increment();
							}
						}

					}
					else {
						// inner_loop = number of rows, inner_last = contiguous row length
						Nd4jLong inner_last;
						Nd4jLong inner_loop = getLength<true>(inner_bases, second_rank, 1, inner_last);
						if (second_rank == 2) {
							LOG_CALLS(1)
							for (Nd4jLong i = 0; i < loopTotal_K; i++) {
								const X* buffer0 = &(bufferX[movement.First()]);
								Z* output0 = &(outputZ[movement.Second()]);
								movement.increment();
								const X* buffer1 = &(bufferX[movement.First()]);
								Z* output1 = &(outputZ[movement.Second()]);
								movement.increment();
								const X* buffer2 = &(bufferX[movement.First()]);
								Z* output2 = &(outputZ[movement.Second()]);
								movement.increment();
								const X* buffer3 = &(bufferX[movement.First()]);
								Z* output3 = &(outputZ[movement.Second()]);
								movement.increment();
								indexInnerReductionConstRankBlock4<X, Z, ReductionOp, 2>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides,
									inner_loop, inner_last);

							}
							for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
								X current;
								const X* buffer0 = &(bufferX[movement.First()]);
								indexInnerReductionConstRank<X, Z, ReductionOp, 2>(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, inner_loop, inner_last);
								movement.increment();
							}

						}
						else if (second_rank == 3) {
							LOG_CALLS(2)
							for (Nd4jLong i = 0; i < loopTotal_K; i++) {
								const X* buffer0 = &(bufferX[movement.First()]);
								Z* output0 = &(outputZ[movement.Second()]);
								movement.increment();
								const X* buffer1 = &(bufferX[movement.First()]);
								Z* output1 = &(outputZ[movement.Second()]);
								movement.increment();
								const X* buffer2 = &(bufferX[movement.First()]);
								Z* output2 = &(outputZ[movement.Second()]);
								movement.increment();
								const X* buffer3 = &(bufferX[movement.First()]);
								Z* output3 = &(outputZ[movement.Second()]);
								movement.increment();
								indexInnerReductionConstRankBlock4<X, Z, ReductionOp, 3>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides,
									inner_loop, inner_last);

							}
							for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
								X current;
								const X* buffer0 = &(bufferX[movement.First()]);
								indexInnerReductionConstRank<X, Z, ReductionOp, 3>(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides,
									inner_loop, inner_last);
								movement.increment();
							}

						}
						else {
							// generic rank: one position per iteration, rows 0..inner_loop
							LOG_CALLS(3)
							//nd4j_printf("-----%d \n", loopTotal);
							for (Nd4jLong i = 0; i < loopTotal; i++) {
								X current;
								const X* buffer0 = &(bufferX[movement.First()]);
								indexInnerReduction<X, Z, ReductionOp>(second_rank, buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, 0,
									inner_loop, inner_last);
								movement.increment();
							}

						}
					}

				}
				else {
					// strided innermost dimension: same dispatch as above using the *_stride overloads
					if (second_rank == 1) {
						LOG_CALLS(10)
						Nd4jLong inner_total = getLength<true>(inner_bases, second_rank);
						for (Nd4jLong i = 0; i < loopTotal_K; i++) {
							const X* buffer0 = &(bufferX[movement.First()]);
							Z* output0 = &(outputZ[movement.Second()]);
							movement.increment();
							const X* buffer1 = &(bufferX[movement.First()]);
							Z* output1 = &(outputZ[movement.Second()]);
							movement.increment();
							const X* buffer2 = &(bufferX[movement.First()]);
							Z* output2 = &(outputZ[movement.Second()]);
							movement.increment();
							const X* buffer3 = &(bufferX[movement.First()]);
							Z* output3 = &(outputZ[movement.Second()]);
							movement.increment();
							indexInnerReductionRank1Block4<X, Z, ReductionOp>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_total, inner_stride);

						}
						if (inner_total >= 2048) {
							for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
								X current;
								const X* buffer0 = &(bufferX[movement.First()]);
								indexInnerReductionRank1Block4WithMerge<X, Z, ReductionOp>(buffer0, current, outputZ[movement.Second()], inner_total, inner_stride);
								movement.increment();
							}
						}
						else {
							for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
								X current;
								const X* buffer0 = &(bufferX[movement.First()]);
								indexInnerReductionRank1<X, Z, ReductionOp>(buffer0, current, outputZ[movement.Second()], inner_total, inner_stride);
								movement.increment();
							}
						}

					}
					else {
						Nd4jLong inner_last;
						Nd4jLong inner_loop = getLength<true>(inner_bases, second_rank, 1, inner_last);
						if (second_rank == 2) {
							LOG_CALLS(11)
							for (Nd4jLong i = 0; i < loopTotal_K; i++) {
								const X* buffer0 = &(bufferX[movement.First()]);
								Z* output0 = &(outputZ[movement.Second()]);
								movement.increment();
								const X* buffer1 = &(bufferX[movement.First()]);
								Z* output1 = &(outputZ[movement.Second()]);
								movement.increment();
								const X* buffer2 = &(bufferX[movement.First()]);
								Z* output2 = &(outputZ[movement.Second()]);
								movement.increment();
								const X* buffer3 = &(bufferX[movement.First()]);
								Z* output3 = &(outputZ[movement.Second()]);
								movement.increment();
								indexInnerReductionConstRankBlock4<X, Z, ReductionOp, 2>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides,
									inner_loop, inner_last, inner_stride);

							}
							for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
								X current;
								const X* buffer0 = &(bufferX[movement.First()]);
								indexInnerReductionConstRank<X, Z, ReductionOp, 2>(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides,
									inner_loop, inner_last, inner_stride);
								movement.increment();
							}

						}
						else if (second_rank == 3) {
							LOG_CALLS(12)
							for (Nd4jLong i = 0; i < loopTotal_K; i++) {
								const X* buffer0 = &(bufferX[movement.First()]);
								Z* output0 = &(outputZ[movement.Second()]);
								movement.increment();
								const X* buffer1 = &(bufferX[movement.First()]);
								Z* output1 = &(outputZ[movement.Second()]);
								movement.increment();
								const X* buffer2 = &(bufferX[movement.First()]);
								Z* output2 = &(outputZ[movement.Second()]);
								movement.increment();
								const X* buffer3 = &(bufferX[movement.First()]);
								Z* output3 = &(outputZ[movement.Second()]);
								movement.increment();
								indexInnerReductionConstRankBlock4<X, Z, ReductionOp, 3>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides,
									inner_loop, inner_last, inner_stride);

							}
							for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
								X current;
								const X* buffer0 = &(bufferX[movement.First()]);
								indexInnerReductionConstRank<X, Z, ReductionOp, 3>(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides,
									inner_loop, inner_last, inner_stride);
								movement.increment();
							}

						}
						else {
							LOG_CALLS(13)
							//nd4j_printf("-------%d inner loop %d inner_last %d\n", loopTotal, inner_loop,inner_last);
							for (Nd4jLong i = 0; i < loopTotal; i++) {
								X current;
								const X* buffer0 = &(bufferX[movement.First()]);
								indexInnerReduction<X, Z, ReductionOp>(second_rank, buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, 0,
									inner_loop, inner_last, inner_stride);
								movement.increment();
							}

						}
					}

				}

			}

			template<typename X, typename Z, typename ReductionOp, bool LastIndexFaster = true>
			void argIndexCaseNonScalar(const  int& first_rank, const int& output_rank, bool squashed, const  int& second_rank,
				const Nd4jLong*& outer_bases,const Nd4jLong* outer_strides,const Nd4jLong* output_strides, const Nd4jLong &output_stride,
				const Nd4jLong*& inner_bases,const Nd4jLong* inner_strides, const X* bufferX, Z* outputZ)
			{

				Nd4jLong total = getLength<LastIndexFaster>(outer_bases, first_rank);
				Nd4jLong inner_stride = true /*LastIndexFaster*/ ? inner_strides[second_rank - 1] : inner_strides[0];
				Nd4jLong outer_stride =  LastIndexFaster  ? outer_strides[second_rank - 1] : outer_strides[0];
				auto func = [first_rank, output_rank, squashed, outer_bases, outer_strides, output_strides, output_stride, second_rank, inner_bases, inner_strides, bufferX, outputZ](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {

					Nd4jLong loopTotal = stop - start;
					Nd4jLong stride = LastIndexFaster ? outer_strides[first_rank - 1] : outer_strides[0];
					if (first_rank == 1) {

						if (stride == 1) {
							ZipGenericCoordsRank1Stride1 movement;
							movement.init(nullptr, nullptr, nullptr, 0, start);
							argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
						}
						else {
							ZipGenericCoordsRank1BothStrideN movement;
							movement.init(nullptr, &stride, &output_stride, 0, start);
							argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);

						}

					}
					else if (squashed && first_rank <= output_rank) {
						if (first_rank == 2) {
							if (output_stride == 1) {
								ZipGenericCoordsConstMovementSecondStride1<2, LastIndexFaster> movement;
								movement.init(outer_bases, outer_strides, nullptr, first_rank, start);
								argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);

							}
							else {
								ZipGenericCoordsConstMovementSecondStrideN<2, LastIndexFaster> movement;
								movement.init(outer_bases, outer_strides, &output_stride, first_rank, start);
								argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);

							}
						}
						else if (first_rank == 3) {
							if (output_stride == 1) {
								ZipGenericCoordsConstMovementSecondStride1<3, LastIndexFaster> movement;
								movement.init(outer_bases, outer_strides, nullptr, first_rank, start);
								argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);

							}
							else {
								ZipGenericCoordsConstMovementSecondStrideN<3, LastIndexFaster> movement;
								movement.init(outer_bases, outer_strides, &output_stride, first_rank, start);
								argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);

							}
						}
						else {
							ZipGenericCoordsMovementSecondStrideN< LastIndexFaster> movement;
							movement.init(outer_bases, outer_strides, &output_stride, first_rank, start);

							argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);

						}

					}
					else { 
						ZipGenericCoordsMovement<LastIndexFaster> movement;
						movement.init(outer_bases, outer_strides, output_strides, first_rank, start);

						argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);

					}

				};
#if 0
				func(0, 0, total, 1);
#else
				//
				uint32_t numThreads = sd::Environment::getInstance().maxMasterThreads();
			    Nd4jLong inner_total = getLength<true>(inner_bases, second_rank);
				if (total * inner_total <= threadingThreshold) {
						numThreads = 1;
				}
				else {
					if (inner_stride > outer_stride && total <= 256) {
						auto desired = total > 4 ? (total / 4) : 1;
						numThreads = numThreads > desired ? desired : numThreads;
					}
				}
				 
				samediff::Threads::parallel_tad(func, 0, total, 1, numThreads);
#endif
			}

			template<typename X, typename Z, typename ReductionOp>
			void  argIndex_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
				char input_order = input.ordering();
				bool try_squash_outer = (input_order == output.ordering()) && output.ews() != 0;
				const Nd4jLong* input_shapeInfo = input.shapeInfo();
				const Nd4jLong* output_shapeInfo = output.shapeInfo();
				const Nd4jLong  rank = input_shapeInfo[0];
				const Nd4jLong* input_bases = &(input_shapeInfo[1]);
				const Nd4jLong* input_strides = &(input_shapeInfo[rank + 1]);
				const Nd4jLong  output_rank = output_shapeInfo[0];
				const Nd4jLong* output_strides = &(output_shapeInfo[output_rank + 1]);
				Nd4jLong new_bases[MAX_RANK];
				Nd4jLong new_strides[MAX_RANK];
				int first_begin, first_end, second_begin, second_end;
				//rePartition into two parts based on the selection
				rePartition(input_order, dimensions, rank, input_bases, input_strides, new_bases, new_strides, first_begin, first_end, second_begin, second_end, try_squash_outer, input_order == 'c');
				int first_rank = first_end - first_begin; //the first rank can be 0 for scalar cases
				int second_rank = second_end - second_begin;
				auto bufferX = input.bufferAsT<X>();
				auto outputZ = output.bufferAsT<Z>();
				const Nd4jLong* outer_bases = &(new_bases[first_begin]);
				const Nd4jLong* outer_strides = &(new_strides[first_begin]);
				const Nd4jLong* inner_bases = &(new_bases[second_begin]);
				const Nd4jLong* inner_strides = &(new_strides[second_begin]);
				const Nd4jLong output_stride = output.ordering()  == 'c' ? output_strides[output_rank-1]:output_strides[0];
				if (input_order == 'c') {
					if (first_rank == 0) {
						argIndexCase1Scalar<X, Z, ReductionOp>(second_rank, inner_bases, inner_strides, bufferX, outputZ);
					}
					else {
						argIndexCaseNonScalar<X, Z, ReductionOp>(first_rank, output_rank, try_squash_outer, second_rank, outer_bases, outer_strides, output_strides,
							output_stride,inner_bases, inner_strides, bufferX, outputZ);
					}
				}
				else {
					if (first_rank == 0) {
						LOG_CALLS(0);
						if (second_rank == 1) {
							argIndexCase1Scalar<X, Z, ReductionOp, false>(second_rank, inner_bases, inner_strides, bufferX, outputZ);
						}
						else {
							argIndexCase1Scalar<X, Z, ReductionOp, true>(second_rank, inner_bases, inner_strides, bufferX, outputZ);
						}
					}
					else {
						LOG_CALLS(1);
						argIndexCaseNonScalar<X, Z, ReductionOp,false>(first_rank, output_rank, try_squash_outer, second_rank, outer_bases, outer_strides, output_strides,
							output_stride, inner_bases, inner_strides, bufferX, outputZ);
					}
				}
			}

			/**
			 * Reduction policy for arg-max.
			 *
			 * Replaces the stored value/index only when `candidate` is strictly greater
			 * than `current`; an equal candidate leaves both untouched.
			 */
			template <typename X, typename Z>
			struct IndexMax {
				static FORCEINLINE void  update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) {
					const bool isBetter = candidate > current;
					if (!isBetter)
						return;
					current = candidate;
					currentIndex = candidateIndex;
				}
			};

			/**
			 * Reduction policy for arg-min.
			 *
			 * Replaces the stored value/index only when `candidate` is strictly less
			 * than `current`; an equal candidate leaves both untouched.
			 */
			template <typename X, typename Z>
			struct IndexMin {
				static FORCEINLINE void  update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) {
					const bool isBetter = candidate < current;
					if (!isBetter)
						return;
					current = candidate;
					currentIndex = candidateIndex;
				}
			};

			/**
			 * Reduction policy for arg-abs-max: keeps the index of the element with the
			 * largest magnitude.
			 *
			 * Both sides are compared by magnitude. Callers may seed `current` with a raw
			 * element (indexInnerReductionRank1 does `current = buffer[0]`), which can be
			 * negative. Comparing against the raw seed only produced the right answer if
			 * the loop happened to revisit that element first. When a candidate wins,
			 * `current` stores its absolute value. Strict comparison means an
			 * equal-magnitude candidate keeps the stored index.
			 */
			template <typename X, typename Z>
			struct IndexAbsMax {
				static FORCEINLINE void  update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) {
					auto absCandidate = sd::math::nd4j_abs<X>(candidate);
					// compare magnitudes: `current` may still hold a raw (possibly negative) seed
					if (absCandidate > sd::math::nd4j_abs<X>(current)) {
						current = absCandidate;
						currentIndex = candidateIndex;
					}
				}
			};

			/**
			 * Reduction policy for arg-abs-min: keeps the index of the element with the
			 * smallest magnitude.
			 *
			 * Both sides are compared by magnitude. Callers may seed `current` with a raw
			 * element (indexInnerReductionRank1 does `current = buffer[0]`). With the
			 * previous `absCandidate < current` test, a negative seed could never be
			 * beaten, since no absolute value is below a negative number. Every slice
			 * starting with a negative element then reported index 0, e.g. [-5, 3] -> 0
			 * instead of 1. When a candidate wins, `current` stores its absolute value.
			 * Strict comparison means an equal-magnitude candidate keeps the stored index.
			 */
			template <typename X, typename Z>
			struct IndexAbsMin {
				static FORCEINLINE void  update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) {
					auto absCandidate = sd::math::nd4j_abs<X>(candidate);
					// compare magnitudes: `current` may still hold a raw (possibly negative) seed
					if (absCandidate < sd::math::nd4j_abs<X>(current)) {
						current = absCandidate;
						currentIndex = candidateIndex;
					}
				}
			};

			
			//////////////////////////////////////////////////////////////////////////
			template<typename X, typename Z>
			void  argMax_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
				return argIndex_<X, Z, IndexMax<X, Z>>(input, output, dimensions);
			}

			template<typename X, typename Z>
			void  argMin_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
				return argIndex_<X, Z, IndexMin<X, Z>>(input, output, dimensions);
			}

			template<typename X, typename Z>
			void  argAbsMax_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
				return argIndex_<X, Z, IndexAbsMax<X, Z>>(input, output, dimensions);
			}

			template<typename X, typename Z>
			void  argAbsMin_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
				return argIndex_<X, Z, IndexAbsMin<X, Z>>(input, output, dimensions);
			}
		}
	}
}