/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

#pragma once
#include <ops/ops.h>
#include <loops/reduce_float.h>
#include <loops/reduce_same.h>
#include <loops/scalar.h>
#include <loops/indexreduce.h>
#include <loops/broadcasting.h>
#include <loops/transform_float.h>
#include <op_enums.h>
#include <loops/transform_strict.h>
#include <helpers/ConstantTadHelper.h>

#ifdef __CUDACC__
#include <loops/cuda/inplace_loops/reduce_same_inplace.h>
#include <loops/cuda/inplace_loops/transform_strict_inplace.h>
#include <loops/cuda/inplace_loops/scalar_inplace.h>
#endif

namespace functions {
	namespace broadcast {
		template <typename X, typename Y, typename Z>
		class Broadcast;
	}

	namespace transform {
		template <typename X>
		class TransformStrict;
	}

	namespace scalar {
	}

	namespace reduce {
		template <typename X, typename Z>
		class ReduceFloatFunction;

        template <typename X>
        class ReduceSameFunction;
	}
}

namespace simdOps {

	template<typename T, typename Z>
	class Pooling2D {
	public:
		static const bool requiresSpecial = true;
#ifdef __CUDACC__
		inline __host__ __device__
#endif
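		// e.g. size = 6, k = 3, s = 2, p = 0: coverAll = false gives (6 - 3) / 2 + 1 = 2
		// output positions, while coverAll = true rounds the division up,
		// (6 - 3 + 2 - 1) / 2 + 1 = 3, so every input element falls in at least one window.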
		static int outSize(int size, int k, int s, int p, bool coverAll) {
			if (coverAll)
				return (size + p * 2 - k + s - 1) / s + 1;
			else
				return (size + p * 2 - k) / s + 1;
		}

#ifdef __CUDACC__
		/**
		* Based on:  https://github.com/pjreddie/darknet/blob/master/src/im2col_kernels.cu
		*/

		static inline __device__ void execSpecialCuda(
			             T *dx, Nd4jLong *xShapeBuffer,
			             Z *result, Nd4jLong *zShapeBuffer,
			             Z *extraParams, 
                         int *allocationPointer, Z *reductionPointer, 
                         Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

			__shared__ int kH;
			__shared__ int kW;
			__shared__ int sH;
			__shared__ int sW;
			__shared__ int pH;
			__shared__ int pW;
			__shared__ int dH;
			__shared__ int dW;
			__shared__ int poolingMode;
			__shared__ Z extraParam0;

			__shared__ int batchSize;
			__shared__ int inChannels;
			__shared__ int outH;
			__shared__ int outW;
			__shared__ int inH;
			__shared__ int inW;

            //__shared__ int *strideIn;
            //__shared__ int *strideOut;
            __shared__ int strideB;
            __shared__ int strideC;
            __shared__ int strideY;
            __shared__ int strideX;

			__shared__ int strideOB;
            __shared__ int strideOC;
            __shared__ int strideOY;
            __shared__ int strideOX;

            __shared__ int length;
            __shared__ int kHEff;
            __shared__ int kWEff;
			__shared__ bool fOrder;
		

			if (threadIdx.x == 0) {
				kH = (int)extraParams[0];
				kW = (int)extraParams[1];
				sH = (int)extraParams[2];
				sW = (int)extraParams[3];
				pH = (int)extraParams[4];
				pW = (int)extraParams[5];
				dH = (int)extraParams[6];			//Dilation, height dimension
				dW = (int)extraParams[7];			//Dilation, width dimension
				poolingMode = (int)extraParams[9];
				extraParam0 = extraParams[10];

				batchSize = shape::sizeAt(xShapeBuffer, 0);
				inChannels = shape::sizeAt(xShapeBuffer, 1);
				outH = shape::sizeAt(zShapeBuffer, 2);
				outW = shape::sizeAt(zShapeBuffer, 3);
				inH = shape::sizeAt(xShapeBuffer, 2);
				inW = shape::sizeAt(xShapeBuffer, 3);

            	strideB = shape::stride(xShapeBuffer)[0];
            	strideC = shape::stride(xShapeBuffer)[1];
            	strideY = shape::stride(xShapeBuffer)[2];
            	strideX = shape::stride(xShapeBuffer)[3];

				strideOB = shape::stride(zShapeBuffer)[0];
            	strideOC = shape::stride(zShapeBuffer)[1];
            	strideOY = shape::stride(zShapeBuffer)[2];
            	strideOX = shape::stride(zShapeBuffer)[3];

            	length = shape::length(zShapeBuffer);

				//Replace kernel H/W with *effective* kernel H/W accounting for dilation
				kHEff = kH + (kH-1)*(dH-1);
				kWEff = kW + (kW-1)*(dW-1);
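
				// e.g. kH = 3, dH = 2 -> kHEff = 3 + 2*1 = 5: the window spans 5 input rows,
				// but only every second row inside it is actually sampled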

				fOrder = shape::order(zShapeBuffer) == 'f';
            }
            __syncthreads();

			int tid = blockIdx.x * blockDim.x + threadIdx.x;

            for (int index = tid; index < length; index += blockDim.x * gridDim.x) {
				const int pw = index % outW;
    			const int ph = (index / outW) % outH;
    			const int c = (index / outW / outH) % inChannels;
    			const int n = index / outW / outH / inChannels;
    			int hstart = sH * ph - pH;
    			int wstart = sW * pw - pW;
    			int hend = hstart + kHEff;
    			int wend = wstart + kWEff;

    			if(hstart < 0){
                    int f = nd4j::math::nd4j_ceil<Z,int>((Z) -hstart / (Z)dH);
                    hstart += f * dH;
                }
                if(wstart < 0){
                    int f = nd4j::math::nd4j_ceil<Z,int>((Z) -wstart / (Z) dW);
                    wstart += f * dW;
                }
                if(hend > inH){
                    int f = nd4j::math::nd4j_ceil<Z,int>((Z) (hend-inH) / (Z) dH);
                    hend -= f * dH;
                }
                if(wend > inW){
                    int f = nd4j::math::nd4j_ceil<Z,int>((Z) (wend-inW) / (Z) dW);
                    wend -= f * dW;
                }
                //Accounts for dilation
    			int pool_size = nd4j::math::nd4j_ceil<double,int>((double) (hend-hstart) / (double) dH) * nd4j::math::nd4j_ceil<double,int>((double) (wend-wstart) / (double) dW);

    			Z sum = poolingMode == 0 ? -nd4j::DataTypeUtils::max<Z>() : static_cast<Z>(0.f);
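    			// mode 0 (max) starts from the lowest representable value so any input wins
    			// the first comparison; modes 1 (avg) and 2 (pnorm) accumulate from zero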

    			T *input_slice = dx + (n * strideB + c * strideC);
    			if (poolingMode == 0) {
    			    for (int h = hstart; h < hend; h += dH) {
      				    for (int w = wstart; w < wend; w += dW) {
        				    Z v = static_cast<Z>(input_slice[h * strideY + w * strideX]);
        				    if (v > sum)
        				        sum = v;
      				    }
    			    }
    			} else if (poolingMode == 1) {
    			    for (int h = hstart; h < hend; h += dH) {
      				    for (int w = wstart; w < wend; w += dW) {
        				    sum += static_cast<Z>(input_slice[h * strideY + w * strideX]);
      				    }
    			    }
    			} else if (poolingMode == 2) {
    			    for (int h = hstart; h < hend; h += dH) {
      				    for (int w = wstart; w < wend; w += dW) {
        				    sum += nd4j::math::nd4j_pow<Z,Z,Z>(static_cast<Z>(nd4j::math::nd4j_abs<T>(input_slice[h * strideY + w * strideX])), extraParam0);
      				    }
    			    }
    			}

				Z res;

    			if (poolingMode == 0) {
                    res = sum;
    			} else if (poolingMode == 1) {
    			    int divide_factor = pool_size;  //Case 0: exclude padding
    			    if ((int) extraParam0 == 1)     //Case 1: include padding
					    divide_factor = kH * kW;

    			    res = sum / static_cast<Z>(divide_factor);
    			} else if (poolingMode == 2) {
                    res = nd4j::math::nd4j_pow<Z,Z,Z>(sum, (Z) 1.0f / extraParam0);
    			}


				if (!fOrder) {
					result[index] = res;
                } else {
					result[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = res;
                }
            }

            __syncthreads();
		}
#endif


static void execSpecial(T *in, Nd4jLong *inShapeBuffer, Z *out, Nd4jLong *outShapeBuffer, Z *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
	// input is  [bS, iC, iH, iW]
	// output is [bS, iC, oH, oW]

	const Nd4jLong kH = (int)extraParams[0];
	const Nd4jLong kW = (int)extraParams[1];
    const Nd4jLong sH = (int)extraParams[2];
    const Nd4jLong sW = (int)extraParams[3];
    const Nd4jLong pH = (int)extraParams[4];
    const Nd4jLong pW = (int)extraParams[5];    
    const Nd4jLong dH = (int)extraParams[6];
    const Nd4jLong dW = (int)extraParams[7];
    Nd4jLong poolingMode = (int)extraParams[9];
    T extraParam0 = extraParams[10];

    if(dH == 0 || dW == 0) {
       printf("Special_ops pooling2d:: dilation must not be zero, but got instead {%lld, %lld} \n", dH, dW);
       throw "";
    }

    const Nd4jLong kHEff = kH + (kH-1)*(dH-1);
    const Nd4jLong kWEff = kW + (kW-1)*(dW-1);

	const int bS = shape::sizeAt(inShapeBuffer, 0);
    const int iC = shape::sizeAt(inShapeBuffer, 1);
    const int iH = shape::sizeAt(inShapeBuffer, 2);
    const int iW = shape::sizeAt(inShapeBuffer, 3);
    const int oH = shape::sizeAt(outShapeBuffer, 2);
    const int oW = shape::sizeAt(outShapeBuffer, 3);            
    const Nd4jLong iStride0 = shape::stride(inShapeBuffer)[0];
    const Nd4jLong iStride1 = shape::stride(inShapeBuffer)[1];
    const Nd4jLong iStride2 = shape::stride(inShapeBuffer)[2];
    const Nd4jLong iStride3 = shape::stride(inShapeBuffer)[3];
    const Nd4jLong oStride0 = shape::stride(outShapeBuffer)[0];
    const Nd4jLong oStride1 = shape::stride(outShapeBuffer)[1];
    const Nd4jLong oStride2 = shape::stride(outShapeBuffer)[2];
    const Nd4jLong oStride3 = shape::stride(outShapeBuffer)[3];

    const Nd4jLong iStep2 = dH*iStride2;
    const Nd4jLong iStep3 = dW*iStride3;        
    const int kProd  = kH*kW;
    const T iStep2Inv = 1./iStep2; 
    const T iStep3Inv = 1./iStep3;

    Nd4jLong hstart, wstart, hend, wend;
    T sum, *pIn;

    if(poolingMode == 0) {        // max 
        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, hstart, wstart, hend, wend) collapse(2))
        for(int b = 0; b < bS; ++b) {
            for(int c = 0; c < iC; ++c) {                                                            
                for(int oh = 0; oh < oH; ++oh) {
                    for(int ow = 0; ow < oW; ++ow) {
                        
                        pIn  = in  + b * iStride0 + c * iStride1;
                        
                        hstart = oh * sH - pH;
                        wstart = ow * sW - pW;                        
                        hend = hstart + kHEff;
                        wend = wstart + kWEff;
                        
                        if(hstart < 0)
                            hstart += dH * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(-hstart) / static_cast<T>(dH));
                        if(wstart < 0)
                            wstart += dW * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(-wstart) / static_cast<T>(dW));
                        if(hend > iH)
                            hend -= dH * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(hend-iH) / static_cast<T>(dH));
                        if(wend > iW)
                            wend -= dW * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(wend-iW) / static_cast<T>(dW));

                        hstart *= iStride2;
                        hend   *= iStride2;
                        wstart *= iStride3;
                        wend   *= iStride3;

                        sum = -nd4j::DataTypeUtils::max<T>();
                                                                    
                        for (Nd4jLong kh = hstart; kh < hend; kh += iStep2)
                            for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) {
                                T val = pIn[kh + kw];
                                if (val > sum)
                                    sum = val;
                            }
                        out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum;
                    }
                }
            }
        }    
    }
/*************************************************************************/    
    else if(poolingMode == 1) {      // avg
        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, hstart, wstart, hend, wend) collapse(2))
        for(int b = 0; b < bS; ++b) {
            for(int c = 0; c < iC; ++c) {                                                            
                for(int oh = 0; oh < oH; ++oh) {
                    for(int ow = 0; ow < oW; ++ow) {
                        
                        pIn  = in  + b * iStride0 + c * iStride1;

                        hstart = oh * sH - pH;
                        wstart = ow * sW - pW;
                        hend = hstart + kHEff;
                        wend = wstart + kWEff;

                        if(hstart < 0)
                            hstart += dH * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(-hstart) / static_cast<T>(dH));
                        if(wstart < 0)
                            wstart += dW * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(-wstart) / static_cast<T>(dW));
                        if(hend > iH)
                            hend -= dH * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(hend-iH) / static_cast<T>(dH));
                        if(wend > iW)
                            wend -= dW * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(wend-iW) / static_cast<T>(dW));

                        hstart *= iStride2;
                        hend   *= iStride2;
                        wstart *= iStride3;
                        wend   *= iStride3;

                        sum = static_cast<T>(0.);
                                            
                        for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) 
                            for (Nd4jLong kw = wstart; kw < wend; kw += iStep3)
                                sum += pIn[kh + kw];
                                
                        if ((int) extraParam0 == 0)         //Exclude padding
                            sum /= static_cast<T>(nd4j::math::nd4j_ceil<double,T>(static_cast<double>(hend - hstart) / static_cast<double>(iStep2)))
                                 * static_cast<T>(nd4j::math::nd4j_ceil<double,T>(static_cast<double>(wend - wstart) / static_cast<double>(iStep3)));   //Accounts for dilation
                        else if ((int) extraParam0 == 1)    //Include padding
                            sum /= kProd;
                    
                        out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum;
                    }
                }
            }
        }
    }    
/*************************************************************************/    
    else if(poolingMode == 2) {  // pnorm
        PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, hstart, wstart, hend, wend) collapse(2))
        for(int b = 0; b < bS; ++b) {
            for(int c = 0; c < iC; ++c) {                                                            
                for(int oh = 0; oh < oH; ++oh) {
                    for(int ow = 0; ow < oW; ++ow) {
                        
                        pIn  = in  + b * iStride0 + c * iStride1;

                        hstart = oh * sH - pH;
                        wstart = ow * sW - pW;
                        hend = hstart + kHEff;
                        wend = wstart + kWEff;

                        if(hstart < 0)
                            hstart += dH * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(-hstart) / static_cast<T>(dH));
                        if(wstart < 0)
                            wstart += dW * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(-wstart) / static_cast<T>(dW));
                        if(hend > iH)
                            hend -= dH * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(hend-iH) / static_cast<T>(dH));
                        if(wend > iW)
                            wend -= dW * (Nd4jLong)nd4j::math::nd4j_ceil<T,Nd4jLong>(static_cast<T>(wend-iW) / static_cast<T>(dW));

                        hstart *= iStride2;
                        hend   *= iStride2;
                        wstart *= iStride3;
                        wend   *= iStride3;

                        sum = static_cast<T>(0.);
                                                                    
                        for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) 
                            for (Nd4jLong kw = wstart; kw < wend; kw += iStep3)
                                sum += nd4j::math::nd4j_pow<T, T, T>(nd4j::math::nd4j_abs<T>(pIn[kh + kw]), extraParam0);
                                
                        sum = nd4j::math::nd4j_pow<T,T,T>(sum, (T) 1. / extraParam0);
                                                          
                        out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum;
                    }
                }
            }
        }
    }
    else {
        nd4j_printf("Special_ops::pooling2d: pooling mode argument must be 0, 1, or 2, but got %i instead!\n", (int) poolingMode);
        throw "";
    }
}

		op_def static T op(T d1, Z *params) {
			return d1;
		}


		/** Calculates a buffer offset (like Shape.getOffset) without validating the input indices.
		*  Negative indices are normally invalid, but are safe here because callers validate indices first.
		*  Uses a loop unrolled specifically for rank 4.
		*/
		static _CUDA_HD int getOffsetUnsafe4(int baseOffset, int *shape, int *stride, int *indices) {
			int offset = baseOffset;
			if (shape[0] != 1) offset += indices[0] * stride[0];
			if (shape[1] != 1) offset += indices[1] * stride[1];
			if (shape[2] != 1) offset += indices[2] * stride[2];
			if (shape[3] != 1) offset += indices[3] * stride[3];
			return offset;
		}


		/**
		* A version of Shape.getOffset that does not validate the input indices.
		* Negative indices are normally invalid, but are safe here because callers validate indices first.
		* Uses a loop unrolled specifically for rank 6, where indices[2] and indices[3] are always zero here.
		*/
		static _CUDA_HD int getOffsetUnsafe6(int baseOffset, int *shape, int *stride, int *indices) {
			int offset = baseOffset;
			if (shape[0] != 1) offset += indices[0] * stride[0];
			if (shape[1] != 1) offset += indices[1] * stride[1];
			if (shape[4] != 1) offset += indices[4] * stride[4];
			if (shape[5] != 1) offset += indices[5] * stride[5];
			return offset;
		}

	};


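    // Single-comparison bounds check, as commonly used in im2col implementations:
    // casting a negative `a` to unsigned wraps it to a large value, so the one
    // unsigned comparison below is equivalent to (a >= 0 && a < b).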
    FORCEINLINE bool is_a_ge_zero_and_a_lt_b(int a, int b) {
        return static_cast<unsigned>(a) < static_cast<unsigned>(b);
    }

	template<typename T>
	class Im2col {
	public:
		static const bool requiresSpecial = true;

		static _CUDA_HD int outSize(int size, int k, int s, int p, bool coverAll) {
			if (coverAll)
				return (size + p * 2 - k + s - 1) / s + 1;
			else
				return (size + p * 2 - k) / s + 1;
		}

#ifdef __CUDACC__
		/**
		* Based on:  https://github.com/pjreddie/darknet/blob/master/src/im2col_kernels.cu
		*/

		static inline __device__ void execSpecialCuda(
			                             T *dx, Nd4jLong *xShapeBuffer,
			                             T *result, Nd4jLong *zShapeBuffer,
			                             T *extraParams, 
                                         int *allocationPointer, T *reductionPointer, 
                                         Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

			/*kernel[0], kernel[1], stride[0], stride[1], padding[0], padding[1], 0, false*/
			__shared__ int kernelHeight, kernelWidth, strideY, strideX, padHeight, padWidth, dY, dX, kSize, samples, depth, height, width, strideex, stridech, strideh, stridew, height_col, width_col, n;
			__shared__ T zeroPadVal;
			__shared__ Nd4jLong *outShape, *outStride, *inShape, *inStride;
			__shared__ char resultOrder;

			if (threadIdx.x == 0) {
			    kernelHeight = (int) extraParams[0];
			    kernelWidth = (int) extraParams[1];
			    strideY = (int) extraParams[2];
			    strideX = (int) extraParams[3];
			    padHeight = (int) extraParams[4];
			    padWidth = (int) extraParams[5];
			    dY = (int) extraParams[6];			//Dilation, height/y dimension
			    dX = (int) extraParams[7];			//Dilation, width/x dimension
                kSize = kernelWidth * kernelHeight;
                zeroPadVal = (T) extraParams[9];	//Value to use when value is padding. Usually 0 but not always

                outShape = shape::shapeOf(zShapeBuffer);
                resultOrder = shape::order(zShapeBuffer);
			    outStride = shape::stride(zShapeBuffer);

			    inShape = shape::shapeOf(xShapeBuffer);
                inStride = shape::stride(xShapeBuffer);

                samples = (int) inShape[0];
                depth = (int) inShape[1];
                height = (int) inShape[2];
                width = (int) inShape[3];


                strideex = (int) inStride[0];
			    stridech = (int) inStride[1];
			    strideh = (int) inStride[2];
                stridew = (int) inStride[3];

			    // (height + 2 * padHeight - kernelHeight) / strideY + 1; //
			    // (width + 2 * padWidth - kernelWidth) / strideX + 1; //
			    height_col = (int) outShape[4];
			    width_col = (int) outShape[5];

			    n = samples * depth * height_col * width_col;
			}
			__syncthreads();

			int index = blockIdx.x * blockDim.x + threadIdx.x;
			for (; index < n; index += blockDim.x*gridDim.x) {
				int h_index = index / width_col;
				int h_col = h_index % height_col;
				int w_col = index % width_col;

				int c_im = h_index / height_col;
				int c_col = c_im * kSize;

				int depth_im = c_im % depth;
				int num_im = c_im / depth;
				int h_offset = h_col * strideY - padHeight;
				int w_offset = w_col * strideX - padWidth;

				T* data_col_ptr = result;

				int i_c = (c_col * height_col + h_col) * width_col + w_col;
				data_col_ptr += i_c;

				T* data_im_ptr = dx;

				data_im_ptr += num_im * strideex + depth_im * stridech + h_offset * strideh + w_offset*stridew;

				for (int i = 0; i < kernelHeight; ++i) {
					for (int j = 0; j < kernelWidth; ++j) {
						int h_im = h_offset + i * dY;
						int w_im = w_offset + j * dX;
						int i_f = 0;
						int i_c_temp = i_c;
						for (int dim = 5; dim >= 0; dim--) {
							i_f += (i_c_temp % outShape[dim])  * outStride[dim];
							i_c_temp = i_c_temp / outShape[dim];
						}
						if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width){
							result[i_f] = data_im_ptr[i * dY * strideh + j * dX * stridew];
						} else result[i_f] = zeroPadVal;

						//result[i_f] = (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) ? data_im_ptr[i * strideh + j*stridew] : 0;
						data_col_ptr += height_col * width_col;
						i_c += height_col * width_col;
					}
				}
			}
		}
#endif


		static void execSpecial(
			T *imBuff,
			Nd4jLong *imShapeBuffer,
			T *colBuff,
			Nd4jLong *colShapeBuffer,
			T *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
			/*kernel[0], kernel[1], stride[0], stride[1], padding[0], padding[1], 0, false*/

			// [bS, iC, iH, iW] is rearranged into columns of shape [bS, iC, kH, kW, oH, oW]

			int kH = (int)extraParams[0];
			int kW = (int)extraParams[1];
			int sH = (int)extraParams[2];
			int sW = (int)extraParams[3];
			int pH = (int)extraParams[4];
			int pW = (int)extraParams[5];
			int dH = (int)extraParams[6];			//Dilation, height/y dimension
			int dW = (int)extraParams[7];			//Dilation, width/x dimension            
            T zeroPadVal = extraParams[9];

            auto colShape  = shape::shapeOf(colShapeBuffer);
            auto colStride = shape::stride(colShapeBuffer);
            auto imShape = shape::shapeOf(imShapeBuffer);
            auto imStride = shape::stride(imShapeBuffer);

            const int bS = imShape[0];
            const int iC = imShape[1];
            const int iH = imShape[2];
            const int iW = imShape[3];
            const int oH = colShape[4];
            const int oW = colShape[5];
            const Nd4jLong colStride0 = colStride[0];
            const Nd4jLong colStride1 = colStride[1];
            const Nd4jLong colStride2 = colStride[2];
            const Nd4jLong colStride3 = colStride[3];
            const Nd4jLong colStride4 = colStride[4];
            const Nd4jLong colStride5 = colStride[5];
            const Nd4jLong imStride0  = imStride[0];
            const Nd4jLong imStride1  = imStride[1];
            const Nd4jLong imStride2  = imStride[2];
            const Nd4jLong imStride3  = imStride[3];

            T *col, *im;
            int imRow, imCol;
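
            // Index mapping sketch, assuming unit stride/dilation and zero padding
            // (sH = sW = dH = dW = 1, pH = pW = 0): col[b, c, kRow, kCol, colH, colW]
            // simply reads im[b, c, colH + kRow, colW + kCol].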
            
            if (shape::order(imShapeBuffer) == 'c' &&  shape::order(colShapeBuffer) == 'c' && shape::strideDescendingCAscendingF(imShapeBuffer) && shape::strideDescendingCAscendingF(colShapeBuffer)) {

                PRAGMA_OMP_PARALLEL_FOR_ARGS(private(col, im, imRow, imCol) collapse(2))
                for (int b = 0; b < bS; b++) {
                    for (int c = 0; c < iC; ++c) {        
                        for (int kRow = 0; kRow < kH; ++kRow) {                        
                            for (int kCol = 0; kCol < kW; ++kCol) {                            
                                for (int colH = 0; colH < oH; ++colH) {
                                    for (int colW = 0; colW < oW; ++colW) {                    
                                
                                        imRow = (-pH + kRow * dH) + colH*sH;
                                        imCol = (-pW + kCol * dW) + colW*sW;
                                        
                                        col = colBuff + b*colStride0 + c*colStride1 + kRow*colStride2 + kCol*colStride3 + colH*colStride4 + colW*colStride5;
                                        im  = imBuff  + b*imStride0  + c*imStride1  + imRow*imStride2 + imCol*imStride3; 
                                                    
                                        if (static_cast<unsigned>(imRow) >= static_cast<unsigned>(iH) || static_cast<unsigned>(imCol) >= static_cast<unsigned>(iW))
                                            *col = zeroPadVal;
                                        else 
                                            *col = *im;
                                    }
                                }
                            }
                        }
                    }
                }  
            }
            else {

                PRAGMA_OMP_PARALLEL_FOR_ARGS(private(im, col, imRow, imCol) collapse(2))
                for (int b = 0; b < bS; b++) {
                    for (int colH = 0; colH < oH; ++colH) {
                        for (int colW = 0; colW < oW; ++colW) {
                            for (int c = 0; c < iC; ++c) {
                                for (int kRow = 0; kRow < kH; ++kRow) {                        
                                    for (int kCol = 0; kCol < kW; ++kCol) {                            
                        
                                        imRow = (-pH + kRow * dH) + colH*sH;
                                        imCol = (-pW + kCol * dW) + colW*sW;
                                        
                                        col = colBuff + b*colStride0 + c*colStride1 + kRow*colStride2 + kCol*colStride3 + colH*colStride4 + colW*colStride5;
                                        im  = imBuff  + b*imStride0  + c*imStride1  + imRow*imStride2 + imCol*imStride3;
                                                    
                                        if (static_cast<unsigned>(imRow) >= static_cast<unsigned>(iH) || static_cast<unsigned>(imCol) >= static_cast<unsigned>(iW))
                                            *col = zeroPadVal;
                                        else 
                                            *col = *im;
                                    }
                                }
                            }
                        }
                    }
                }
            }
		}

		op_def static T op(T d1, T *params) {
			return d1;
		}


		/** Calculates a buffer offset (like Shape.getOffset) without validating the input indices.
		*  Negative indices are normally invalid, but are safe here because callers validate indices first.
		*  Uses a loop unrolled specifically for rank 4.
		*/
		static _CUDA_HD int getOffsetUnsafe4(int baseOffset, int *shape, int *stride, int *indices) {
			int offset = baseOffset;
			if (shape[0] != 1) offset += indices[0] * stride[0];
			if (shape[1] != 1) offset += indices[1] * stride[1];
			if (shape[2] != 1) offset += indices[2] * stride[2];
			if (shape[3] != 1) offset += indices[3] * stride[3];
			return offset;
		}


		/**
		* A version of Shape.getOffset that does not validate the input indices.
		* Negative indices are normally invalid, but are safe here because callers validate indices first.
		* Uses a loop unrolled specifically for rank 6, where indices[2] and indices[3] are always zero here.
		*/
		static _CUDA_HD int getOffsetUnsafe6(int baseOffset, int *shape, int *stride, int *indices) {
			int offset = baseOffset;
			if (shape[0] != 1) offset += indices[0] * stride[0];
			if (shape[1] != 1) offset += indices[1] * stride[1];
			if (shape[4] != 1) offset += indices[4] * stride[4];
			if (shape[5] != 1) offset += indices[5] * stride[5];
			return offset;
		}

	};

	template<typename T, typename Z>
	class Histogram {
	public:
		static const bool requiresSpecial = true;

#ifdef __CUDACC__
		static inline __device__ void execSpecialCuda(
			                 T *dx, Nd4jLong *xShapeBuffer,
			                 Z *result, Nd4jLong *zShapeBuffer,
			                 Z *extraParams, 
                             int *allocationPointer, Z *reductionPointer, 
                             Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
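			// intentionally empty: this legacy special op has no CUDA kernel body here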



		};
#endif

		static void execSpecial(
				T *dx,
				Nd4jLong *xShapeBuffer,
				Z *result,
				Nd4jLong *zShapeBuffer,
				Z *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
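			// intentionally empty: the host path of this legacy special op is a no-op stub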



		}


        op_def static T op(T d1, Z *params) {
            return d1;
        }
	};

	template<typename X>
	class Col2Im {

	public:
		static const bool requiresSpecial = true;
#ifdef __CUDACC__
		/**
		* https://github.com/pjreddie/darknet/blob/master/src/col2im_kernels.cu
		*/

		static inline __device__ void execSpecialCuda(
			X *dx, Nd4jLong *xShapeBuffer,
			X *result, Nd4jLong *zShapeBuffer,
			X *extraParams, int *allocationPointer, 
            X *reductionPointer, 
            Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

		    __shared__ int strideex, stridech, stridekrow, stridekcol, striderow, stridecol, kernelHeight, kernelWidth, strideY, strideX, padHeight, padWidth, imgHeight, imgWidth, dY, dX, samples, depth, imgH, imgW, height_col, width_col, n, kEffectiveW, kEffectiveH;
		    __shared__ Nd4jLong *inShape, *inStride, *outShape, *outStride;
		    __shared__ char resultOrder;

		    if (threadIdx.x == 0) {
			    inShape = shape::shapeOf(xShapeBuffer);
                inStride = shape::stride(xShapeBuffer);

			    strideex = (int) inStride[0];
                stridech = (int) inStride[1];
                stridekrow = (int) inStride[2];
                stridekcol = (int) inStride[3];
                striderow = (int) inStride[4];
                stridecol = (int) inStride[5];

			    kernelHeight = (int) inShape[2];
                kernelWidth = (int) inShape[3];

                strideY = (int) extraParams[0];
                strideX = (int) extraParams[1];
                padHeight = (int) extraParams[2];
			    padWidth = (int) extraParams[3];
                imgHeight = (int) extraParams[4];
                imgWidth = (int) extraParams[5];
                dY = (int) extraParams[6];			//Dilation in height/y dimension
                dX = (int) extraParams[7];			//Dilation in width/x dimension

			    outShape = shape::shapeOf(zShapeBuffer);
			    resultOrder = shape::order(zShapeBuffer);
			    outStride = shape::stride(zShapeBuffer);

                samples = (int) outShape[0];
                depth = (int) outShape[1];
                imgH = (int) outShape[2];
                imgW = (int) outShape[3];

                height_col = inShape[4];//(imgHeight + 2 * padHeight - kernelHeight) / strideX + 1;
			    width_col = inShape[5];//(imgWidth + 2 * padWidth - kernelWidth) / strideY + 1;

			    n = samples * depth * imgHeight * imgWidth;

			    //Effective kernel size, accounting for dilation
                kEffectiveW = kernelWidth + (kernelWidth - 1) * (dX - 1);
                kEffectiveH = kernelHeight + (kernelHeight - 1) * (dY - 1);
			}
		    __syncthreads();

			for (int i = (blockDim.x * blockIdx.x) + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
				X val = 0;
				int w_im = i % imgWidth + padWidth;
				int h_im = (i / imgWidth) % imgHeight + padHeight;
				int c_im = i / (imgWidth * imgHeight);

				int num_im = c_im / depth;
				int depth_im = c_im % depth;

				// compute the start and end of the output
				// These are the indexes for dimensions ??? in the 6d col matrix
				int w_col_start = (w_im < kEffectiveW) ? 0 : (w_im - kEffectiveW) / strideX + 1;
				int w_col_end = nd4j::math::nd4j_min<int>(w_im / strideX + 1, width_col);

				int h_col_start = (h_im < kEffectiveH) ? 0 : (h_im - kEffectiveH) / strideY + 1;
				int h_col_end = nd4j::math::nd4j_min<int>(h_im / strideY + 1, height_col);
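
				// e.g. strideX = 1, kEffectiveW = 3: input column w_im = 5 is covered by
				// windows w_col in [3, 5], since start = (5 - 3) / 1 + 1 = 3 and
				// end = min(5 / 1 + 1, width_col) = 6 (exclusive)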


				//Iterate over col entries in the 6d array... these are added up
				for (int h_col = h_col_start; h_col < h_col_end; h_col += 1) {
					for (int w_col = w_col_start; w_col < w_col_end; w_col += 1) {
						int h_k = (h_im - h_col * strideY);
						int w_k = (w_im - w_col * strideX);
						
						if(h_k % dY == 0 && w_k % dX == 0){
							h_k /= dY;
							w_k /= dX;

							int data_col_index = num_im * strideex + depth_im * stridech + h_k * stridekrow + w_k * stridekcol + h_col * striderow + w_col * stridecol;
							val += dx[data_col_index];
						}
					}
				}
				int i_f = 0;
				int i_c = i;
				for (int dim = 3; dim >= 0; dim--)
				{
					i_f += (i_c % outShape[dim])  * outStride[dim];
					i_c = i_c / outShape[dim];
				}
				result[i_f] = val;
			}
		}
#endif

		static void execSpecial(
			X *colBuff,
			Nd4jLong *colShapeBuffer,
			X *imBuff,
			Nd4jLong *imShapeBuffer,
			X *extraParams,
			Nd4jLong *tadShapeInfo,
			Nd4jLong *tadOffsets) {

            // [bS, iC, kH, kW, oH, oW] columns are accumulated back into the [bS, iC, iH, iW] image

            auto colShape  = shape::shapeOf(colShapeBuffer);
            auto colStride = shape::stride(colShapeBuffer);
            auto imShape = shape::shapeOf(imShapeBuffer);
            auto imStride = shape::stride(imShapeBuffer);            

            const int sH = (int)extraParams[0];
            const int sW = (int)extraParams[1];
            const int pH = (int)extraParams[2];
            const int pW = (int)extraParams[3];
            const int iH = (int)extraParams[4];
            const int iW = (int)extraParams[5];
            const int dH = (int)extraParams[6];     
            const int dW = (int)extraParams[7];     

            const int bS = imShape[0];
            const int iC = imShape[1];
            const int kH = colShape[2];
            const int kW = colShape[3];                    
            const int oH = colShape[4];
            const int oW = colShape[5];
            const Nd4jLong colStride0 = colStride[0];
            const Nd4jLong colStride1 = colStride[1];
            const Nd4jLong colStride2 = colStride[2];
            const Nd4jLong colStride3 = colStride[3];
            const Nd4jLong colStride4 = colStride[4];
            const Nd4jLong colStride5 = colStride[5];
            const Nd4jLong imStride0  = imStride[0];
            const Nd4jLong imStride1  = imStride[1];
            const Nd4jLong imStride2  = imStride[2];
            const Nd4jLong imStride3  = imStride[3];

            auto zLength = shape::length(imShapeBuffer);

            // initial zeroing of image content
            memset(imBuff, 0, zLength * sizeof(X));


            X *col, *im;
            int imRow, imCol;
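
            // col2im is the adjoint of im2col: overlapping windows are summed back, so an
            // image pixel covered by several kernel positions accumulates all of their values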

            if (shape::order(colShapeBuffer) == 'c' &&  shape::order(imShapeBuffer) == 'c' && shape::strideDescendingCAscendingF(colShapeBuffer) && shape::strideDescendingCAscendingF(imShapeBuffer)) {

                PRAGMA_OMP_PARALLEL_FOR_ARGS(private(col, im, imRow, imCol) collapse(2))
                for (int b = 0; b < bS; b++) {        
                    for (int c = 0; c < iC; ++c) {                    
                        for (int kRow = 0; kRow < kH; ++kRow) {                        
                            for (int kCol = 0; kCol < kW; ++kCol) {                            
                                for (int colH = 0; colH < oH; ++colH) {
                                    for (int colW = 0; colW < oW; ++colW) {                    

                                        imRow = (-pH + kRow * dH) + colH*sH;
                                        imCol = (-pW + kCol * dW) + colW*sW;

                                        col = colBuff + b*colStride0 + c*colStride1 + kRow*colStride2 + kCol*colStride3 + colH*colStride4 + colW*colStride5;
                                        im  = imBuff  + b*imStride0  + c*imStride1  + imRow*imStride2 + imCol*imStride3;

                                        if (static_cast<unsigned>(imRow) < static_cast<unsigned>(iH) && static_cast<unsigned>(imCol) < static_cast<unsigned>(iW))
                                            *im += *col;
                                    }
                                }
                            }
                        }
                    }
                }  
            }
            else {

                PRAGMA_OMP_PARALLEL_FOR_ARGS(private(im, col, imRow, imCol))
                for (int b = 0; b < bS; b++) {        
                    for (int colH = 0; colH < oH; ++colH) {
                        for (int colW = 0; colW < oW; ++colW) {
                            for (int c = 0; c < iC; ++c) {                        
                                for (int kRow = 0; kRow < kH; ++kRow) {                        
                                    for (int kCol = 0; kCol < kW; ++kCol) {                            
                        
                                        imRow = (-pH + kRow * dH) + colH*sH;
                                        imCol = (-pW + kCol * dW) + colW*sW;
                                        
                                        col = colBuff + b*colStride0 + c*colStride1 + kRow*colStride2 + kCol*colStride3 + colH*colStride4 + colW*colStride5;
                                        im  = imBuff  + b*imStride0  + c*imStride1  + imRow*imStride2 + imCol*imStride3;

                                        if (static_cast<unsigned>(imRow) < static_cast<unsigned>(iH) && static_cast<unsigned>(imCol) < static_cast<unsigned>(iW))
                                            *im += *col;
                                    }
                                }
                            }
                        }                           
                    }
                }  
            }
        }

		op_def static X op(X d1, X *params) {
			return d1;
		}


		/** Calculates a buffer offset (like Shape.getOffset) without validating the input indices.
		*  Negative indices are normally invalid, but are safe here because callers validate indices first.
		*  Uses a loop unrolled specifically for rank 4.
		*/
		static _CUDA_HD int getOffsetUnsafe4(int baseOffset, int *shape, int *stride, int *indices) {
			int offset = baseOffset;
			if (shape[0] != 1) offset += indices[0] * stride[0];
			if (shape[1] != 1) offset += indices[1] * stride[1];
			if (shape[2] != 1) offset += indices[2] * stride[2];
			if (shape[3] != 1) offset += indices[3] * stride[3];
			return offset;
		}

		/** A version of Shape.getOffset that does not validate the input indices.
		* Negative indices are normally invalid, but are safe here because callers validate indices first.
		* Uses a loop unrolled specifically for rank 6, where indices[2] and indices[3] are always zero here.
		*/
		static _CUDA_HD int getOffsetUnsafe6(int baseOffset, int *shape, int *stride, int *indices) {
			int offset = baseOffset;
			if (shape[0] != 1) offset += indices[0] * stride[0];
			if (shape[1] != 1) offset += indices[1] * stride[1];
			if (shape[4] != 1) offset += indices[4] * stride[4];
			if (shape[5] != 1) offset += indices[5] * stride[5];
			return offset;
		}

	};


	template<typename X>
	class Reverse {
	public:
		static const bool requiresSpecial = true;

#ifdef __CUDACC__
		static inline __device__ void execSpecialCuda(X *dx, Nd4jLong *xShapeBuffer, 
                                                    X *result, Nd4jLong *zShapeBuffer, 
                                                    X *extraParams, int *allocationPointer, 
                                                    X *reductionPointer, 
                                                    Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

            __shared__ Nd4jLong xLength;
			__shared__ int xEWS;
            __shared__ char xOrder;
            __shared__ Nd4jLong sLength;
            __shared__ X *shmem;
            int tid = threadIdx.x + blockIdx.x * blockDim.x;

            if (threadIdx.x == 0) {
                xLength = shape::length(xShapeBuffer);
			    xEWS = shape::elementWiseStride(xShapeBuffer);
                xOrder = shape::order(xShapeBuffer);
                sLength = xLength - 1;

                extern __shared__ unsigned char shrd[];
                shmem = (X *) shrd;
            }
            __syncthreads();



            if (dx == result) {

                if (xEWS == 1) {
                    for (int e = tid; e < xLength / 2; e += blockDim.x * gridDim.x) {
                        Nd4jLong idx = sLength - e;
                        X tmp = dx[e];
                        dx[e] = dx[idx];
                        dx[idx] = tmp;
                    }
                } else if (xEWS > 1) {
                    for (int e = tid; e < xLength / 2; e += blockDim.x * gridDim.x) {
                        Nd4jLong idx1 = (sLength - e) * xEWS;
                        Nd4jLong idx2 =  e * xEWS;
                        X tmp = dx[idx2];
                        dx[idx2] = dx[idx1];
                        dx[idx1] = tmp;
                    }
                } 
                else {                    

					for (int e = tid; e < xLength / 2; e += blockDim.x * gridDim.x) {
                        auto xOffset = shape::getIndexOffset(e, xShapeBuffer, xLength);
                        auto zOffset = shape::getIndexOffset(sLength - e, xShapeBuffer, xLength);
                        // in-place: swap the mirrored elements instead of overwriting one side
                        X tmp = dx[xOffset];
                        dx[xOffset] = dx[zOffset];
                        dx[zOffset] = tmp;
					}
                }

            } else {
                __shared__ int zEWS;
				__shared__ char zOrder;

				if (threadIdx.x == 0) {
				    zEWS = shape::elementWiseStride(zShapeBuffer);
				    zOrder = shape::order(zShapeBuffer);
				}
				__syncthreads();

                if (xEWS == 1 && zEWS == 1 && xOrder == zOrder) {
                    // loop for whole array
                    for (int e = tid; e < xLength; e += blockDim.x * gridDim.x) {
                        result[sLength - e] = dx[e];
                    }
                } else if (xEWS >= 1 && zEWS >= 1 && xOrder == zOrder) {

                    for (int e = tid; e < xLength; e += blockDim.x * gridDim.x) {
                        result[(sLength - e) * zEWS] = dx[e * xEWS];
                    }
                } 
                else {                  

                    for (int e = tid; e < xLength; e += blockDim.x * gridDim.x) {
                        auto xOffset = shape::getIndexOffset(e, xShapeBuffer, xLength);
                        auto zOffset = shape::getIndexOffset(sLength - e, zShapeBuffer, xLength);
                        result[zOffset] = dx[xOffset];
                    }
                }
            }
		}

#endif


		static void execSpecial(X *dx, Nd4jLong *xShapeBuffer, X *result, Nd4jLong *zShapeBuffer, X *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
			Nd4jLong xLength = shape::length(xShapeBuffer);
			int xEWS = shape::elementWiseStride(xShapeBuffer);
            char xOrder = shape::order(xShapeBuffer);
            Nd4jLong sLength = xLength - 1;
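
			// e.g. for length 5, pairs (0, 4) and (1, 3) are exchanged (or copied when
			// writing out-of-place) and the middle element 2 maps to itself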

			// in-place case: swap mirrored element pairs
			if (dx == result) {
				if (xEWS == 1) {
                    PRAGMA_OMP_PARALLEL_FOR_SIMD
                    for (Nd4jLong e = 0; e < xLength / 2; e++) {
                        Nd4jLong idx = sLength - e;
                        auto tmp = dx[e];
                        dx[e] = dx[idx];
                        dx[idx] = tmp;
                    }
				} else if (xEWS > 1) {
                    PRAGMA_OMP_PARALLEL_FOR_SIMD
                    for (Nd4jLong e = 0; e < xLength / 2; e++) {
                        Nd4jLong idx1 = (sLength - e) * xEWS;
                        Nd4jLong idx2 =  e * xEWS;
                        auto tmp = dx[idx2];
                        dx[idx2] = dx[idx1];
                        dx[idx1] = tmp;
                    }
				} 
                else {

                    PRAGMA_OMP_PARALLEL_FOR_SIMD
                    for (Nd4jLong e = 0; e < xLength / 2; e++) {
                        auto xOffset = shape::getIndexOffset(e, xShapeBuffer, xLength);
                        auto zOffset = shape::getIndexOffset(sLength - e, xShapeBuffer, xLength);

                        // in-place: swap the mirrored elements instead of overwriting one side
                        auto tmp = dx[xOffset];
                        dx[xOffset] = dx[zOffset];
                        dx[zOffset] = tmp;
                    }
				}
			} else {
				// out-of-place case: copy each element to its mirrored position
				auto zEWS = shape::elementWiseStride(zShapeBuffer);
				auto zOrder = shape::order(zShapeBuffer);

				if (xEWS == 1 && zEWS == 1 && xOrder == zOrder) {
                    PRAGMA_OMP_PARALLEL_FOR_SIMD
					for (Nd4jLong e = 0; e < xLength; e++) {
						result[sLength - e] = dx[e];
					}
				} else if (xEWS >= 1 && zEWS >= 1 && xOrder == zOrder) {
                    PRAGMA_OMP_PARALLEL_FOR_SIMD
					for (Nd4jLong e = 0; e < xLength; e++) {
						result[(sLength - e) * zEWS] = dx[e * xEWS];
					}
				} 
                else {

                    PRAGMA_OMP_PARALLEL_FOR_SIMD
					for (Nd4jLong e = 0; e < xLength; e++) {
						auto xOffset = shape::getIndexOffset(e, xShapeBuffer, xLength);
                        auto zOffset = shape::getIndexOffset(sLength - e, zShapeBuffer, xLength);
						result[zOffset] = dx[xOffset];
					}
				}
			}
		}

        op_def static X op(X d1, X *params) {
            return d1;
        }
	};

	template<typename X>
	class SoftMax {
	public:
		static const bool requiresSpecial = true;

#ifdef __CUDACC__
		/**
		*
		*/

		static inline __device__ void execSpecialCuda(
			void *vx, Nd4jLong *xShapeBuffer,
			void *vresult, Nd4jLong *zShapeBuffer,
			void *vextraParams,
			int *allocationPointer, void *reductionPointer, 
            Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

            auto dx = reinterpret_cast<X *>(vx);
            auto result = reinterpret_cast<X *>(vresult);
            auto extraParams = reinterpret_cast<X *>(vextraParams);

			auto shape = shape::shapeOf(xShapeBuffer);
			__shared__ X maxResult;
			__shared__ Nd4jLong *maxResultShapeBuffer;

			auto length = shape::length(xShapeBuffer);

			auto stride = shape::stride(xShapeBuffer);
			//compute the row wise maxes

			__shared__ Nd4jLong maxShape[2];

			// it's always 2d here
			__shared__ Nd4jLong tempBuffer[8];

			if (threadIdx.x == 0) {
			    maxResult = (X) 0.0;
			    maxShape[0] = shape[0];
			    maxShape[1] = 1;
                maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT<X>(), maxShape, tempBuffer);
			}
			__syncthreads();


			functions::reduce::ReduceSameInplace<X>::execScalarCudaLegacy(nd4j::reduce::Max, dx, xShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr);
			__syncthreads();

			//subtract max of each row
			functions::scalar::ScalarInplace<X,X,X>::transformCudaLegacy(nd4j::scalar::Subtract, &maxResult, dx, xShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer);
			__syncthreads();

			//after subtracting the row wise maxes take the exp
			functions::transform::TransformStrictInplace<X>::transformCudaLegacy(nd4j::transform::Exp, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets);
			__syncthreads();

			//take the sum for the exponential
			functions::reduce::ReduceSameInplace<X>::execScalarCudaLegacy(nd4j::reduce::Sum, result, zShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr);
			__syncthreads();

			//divide by the sum
			functions::scalar::ScalarInplace<X,X,X>::transformCudaLegacy(nd4j::scalar::Divide, &maxResult, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer);
		}
#endif

		static void execSpecial(
            void *vx,
            Nd4jLong *xShapeInfo,
            void *vz,
            Nd4jLong *zShapeInfo,
            void *vextraParams,
            Nd4jLong *tadShapeInfo,
            Nd4jLong *tadOffsets) {

            auto x = reinterpret_cast<X *>(vx);
            auto z = reinterpret_cast<X *>(vz);
            auto extraParams = reinterpret_cast<X *>(vextraParams);

            if (shape::isMatrix(xShapeInfo)) {

                if(shape::equalsStrict(xShapeInfo, zShapeInfo)) {
                    if (tadShapeInfo == nullptr) {
                        auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, 1);
                        tadShapeInfo = tadPack.primaryShapeInfo();
                        tadOffsets = tadPack.primaryOffsets();
                    }
                    
                    const uint tadLen    = shape::length(tadShapeInfo);
                    const uint numOfTads = shape::length(xShapeInfo) / tadLen;
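
                    // each TAD ("tensor along dimension") is one row here: numOfTads rows of
                    // tadLen elements each, with softmax applied independently per row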
        
                    if(shape::elementWiseStride(tadShapeInfo) == 1) {

                        PRAGMA_OMP_PARALLEL_FOR_SIMD
                        for (uint i = 0; i < numOfTads; ++i) {

                            X* inBuff  = x + tadOffsets[i];
                            X* outBuff = z + tadOffsets[i];

                            X max = -nd4j::DataTypeUtils::max<X>();
                            X sum = 0;
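
                            // numerically stable softmax: subtract the row max before exp,
                            // e.g. row [1, 2, 3]: max = 3, exp(x - max) ~ [0.135, 0.368, 1.0],
                            // sum ~ 1.503, output ~ [0.090, 0.245, 0.665]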
                        
                            for(uint j = 0; j < tadLen; ++j)
                                max = nd4j::math::nd4j_max<X>(max, inBuff[j]);            
            
                            for (uint j = 0; j < tadLen; ++j) {
                                X temp = nd4j::math::nd4j_exp<X,X>(inBuff[j] - max);
                                outBuff[j] = temp;
                                sum += temp;
                            }
            
                            for (uint j = 0; j < tadLen; ++j)
                                outBuff[j] /= sum;
                        }
                    }
                    else {

                        uint xShapeInfoCast[MAX_RANK];
                        bool canCast = nd4j::DataTypeUtils::castShapeInfo(tadShapeInfo, xShapeInfoCast);

                        auto offsets = new Nd4jLong[tadLen];
                        shape::calcOffsets(tadShapeInfo, offsets);

                        PRAGMA_OMP_PARALLEL_FOR_SIMD
                        for (uint i = 0; i < numOfTads; ++i) {                        

                            X* inBuff  = x  + tadOffsets[i];
                            X* outBuff = z + tadOffsets[i];

                            X max = -nd4j::DataTypeUtils::max<X>();
                            X sum = 0.f;                                

                            for(uint j = 0; j < tadLen; ++j)                                 
                                max = nd4j::math::nd4j_max<X>(max, inBuff[offsets[j]]);                            
            
                            for (uint j = 0; j < tadLen; ++j) {
                                X temp = nd4j::math::nd4j_exp<X,X>(inBuff[offsets[j]] - max);
                                outBuff[offsets[j]] = temp;
                                sum += temp;
                            }

                            for (uint j = 0; j < tadLen; ++j)
                                outBuff[offsets[j]] /= sum;
                        }
                        delete []offsets;
                    }
                }
                else {

                    auto shape = shape::shapeOf(xShapeInfo);
                    //iterate along rows
                    int dimension[1] = { 0 };
                    int maxDimension[1] = { 1 };
                    //compute the row wise maxes
                    auto maxResult = new X[shape[0]];
                    for (int i = 0; i < shape[0]; i++)
                        maxResult[i] = 0.0;
                    Nd4jLong maxShape[2] = { shape[0], 1 };
                    auto maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT<X>(), maxShape);
                    functions::reduce::ReduceSameFunction<X>::exec(nd4j::reduce::Max, x, xShapeInfo, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1,  nullptr, nullptr);

                    //subtract max of each row
                    functions::broadcast::Broadcast<X, X, X>::exec(nd4j::broadcast::Subtract, x, xShapeInfo, maxResult, maxResultShapeBuffer, z, zShapeInfo, dimension, 1, nullptr, nullptr, nullptr, nullptr);

                    //after subtracting the row wise maxes take the exp
                    functions::transform::TransformStrict<X>::exec(nd4j::transform::Exp, z, zShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets);

                    //take the sum for the exponential
                    functions::reduce::ReduceSameFunction<X>::exec(nd4j::reduce::Sum, z, zShapeInfo, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr);

                    //divide by the sum
                    functions::broadcast::Broadcast<X,X,X>::exec(nd4j::broadcast::Divide, z, zShapeInfo, maxResult, maxResultShapeBuffer, z, zShapeInfo, dimension, 1, nullptr, nullptr, nullptr, nullptr);

                    delete[] maxResultShapeBuffer;
                    delete[] maxResult;
                }                
            }
            else if (shape::isVector(xShapeInfo)) {
                auto max = -nd4j::DataTypeUtils::max<X>();
                X sum = 0;
                int elementWiseStride = shape::elementWiseStride(xShapeInfo);
                int resultElementWiseStride = shape::elementWiseStride(zShapeInfo);
                int length = shape::length(xShapeInfo);
                if (elementWiseStride >= 1 && resultElementWiseStride >= 1) {
                    if (elementWiseStride == 1 && resultElementWiseStride == 1) {

                        for (int i = 0; i < length; i++) {
                            max = nd4j::math::nd4j_max<X>(max, x[i]);
                        }

                        for (int i = 0; i < length; i++) {
                            z[i] = nd4j::math::nd4j_exp<X,X>(x[i] - max);
                            sum += z[i];
                        }

                        PRAGMA_OMP_SIMD
                        for (int i = 0; i < length; i++) {
                            z[i] /= sum;
                        }
                    }
                    else {

                        for (int i = 0; i < length; i++) {
                            max = nd4j::math::nd4j_max<X>(max, x[i * elementWiseStride]);
                        }

                        for (int i = 0; i < length; i++) {
                            auto r = nd4j::math::nd4j_exp<X, X>(x[i * elementWiseStride] - max);
                            z[i * resultElementWiseStride] = r;
                            sum += r;
                        }

                        for (int i = 0; i < length; i++) {
                            z[i * resultElementWiseStride] /= sum;
                        }
                    }
                }
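                // element-wise strides < 1 (views that need full index arithmetic) are not handled by this legacy path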
            }
        }

		op_def static X op(X d1, X *params) {
			return d1;
		}
	};



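	/**
	* log(softmax(x)): runs the softmax pipeline, then applies a log transform
	*/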
	template<typename X>
	class LogSoftMax {
	public:
		static const bool requiresSpecial = true;
#ifdef __CUDACC__
		/**
		* CUDA log-softmax, mirroring the host pipeline below via the legacy in-place kernels
		*/

		static inline __device__ void execSpecialCuda(
            			void *vx, Nd4jLong *xShapeBuffer,
            			void *vresult, Nd4jLong *zShapeBuffer,
            			void *vextraParams,
            			int *allocationPointer, void *reductionPointer, 
                        Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

			auto shape = shape::shapeOf(xShapeBuffer);
			auto stride = shape::stride(xShapeBuffer);
			//iterate along rows

            auto dx = reinterpret_cast<X *>(vx);
            auto result = reinterpret_cast<X *>(vresult);
            auto extraParams = reinterpret_cast<X *>(vextraParams);

			__shared__ X maxResult;
			__shared__ Nd4jLong *maxResultShapeBuffer;
			if (threadIdx.x == 0) {
				maxResult = (X) 0.0;
			}
			__syncthreads();
			//compute the row wise maxes

			Nd4jLong maxShape[2] = { shape[0], 1 };
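			// shared scratch backing the small [shape[0], 1] shape buffer, filled in once by thread 0 below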
			__shared__ Nd4jLong tempBuffer[8];

			if (threadIdx.x == 0)
                maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT<X>(), maxShape, tempBuffer);
			__syncthreads();

			functions::reduce::ReduceSameInplace<X>::execScalarCudaLegacy(nd4j::reduce::Max, dx, xShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr);
			__syncthreads();

			//subtract max of each row
			functions::scalar::ScalarInplace<X,X,X>::transformCudaLegacy(nd4j::scalar::Subtract, &maxResult, dx, xShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer);
			__syncthreads();

			//after subtracting the row wise maxes take the exp
			functions::transform::TransformStrictInplace<X>::transformCudaLegacy(nd4j::transform::Exp, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets);
			__syncthreads();

			//take the sum for the exponential
			functions::reduce::ReduceSameInplace<X>::execScalarCudaLegacy(nd4j::reduce::Sum, result, zShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr);
			__syncthreads();

			//divide by the sum
			functions::scalar::ScalarInplace<X,X,X>::transformCudaLegacy(nd4j::scalar::Divide, &maxResult, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer);
			__syncthreads();

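			// finally take the log of the softmax values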
			functions::transform::TransformStrictInplace<X>::transformCudaLegacy(nd4j::transform::Log, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets);

		}
#endif


		static void execSpecial(
			void *vx,
			Nd4jLong *xShapeBuffer,
			void *vresult,
			Nd4jLong *zShapeBuffer,
			void *vextraParams,
			Nd4jLong *tadShapeInfo,
			Nd4jLong *tadOffsets) {

            auto dx = reinterpret_cast<X *>(vx);
            auto result = reinterpret_cast<X *>(vresult);
            auto extraParams = reinterpret_cast<X *>(vextraParams);

			if (shape::isMatrix(xShapeBuffer, 2)) {
				auto shape = shape::shapeOf(xShapeBuffer);
				//iterate along rows
				int dimension[1] = { 0 };
				int maxDimension[1] = { 1 };
				//compute the row wise maxes
				auto maxResult = new X[shape[0]];

                PRAGMA_OMP_SIMD
				for (int i = 0; i < shape[0]; i++)
					maxResult[i] = 0.0;

				Nd4jLong maxShape[2] = { shape[0], 1 };
                auto maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT<X>(), maxShape);
				functions::reduce::ReduceSameFunction<X>::exec(nd4j::reduce::Max, dx, xShapeBuffer, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr);

				//subtract max of each row
				functions::broadcast::Broadcast<X,X,X>::exec(nd4j::broadcast::Subtract, dx, xShapeBuffer, maxResult, maxResultShapeBuffer, result, zShapeBuffer, dimension, 1, nullptr, nullptr, nullptr, nullptr);

				//after subtracting the row wise maxes take the exp
				functions::transform::TransformStrict<X>::exec(nd4j::transform::Exp, result, zShapeBuffer, result, zShapeBuffer, extraParams, tadShapeInfo, tadOffsets);

				//take the sum for the exponential
				functions::reduce::ReduceSameFunction<X>::exec(nd4j::reduce::Sum, result, zShapeBuffer, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr);

				//divide by the sum
				functions::broadcast::Broadcast<X,X,X>::exec(nd4j::broadcast::Divide, result, zShapeBuffer, maxResult, maxResultShapeBuffer, result, zShapeBuffer, dimension, 1, nullptr, nullptr, nullptr, nullptr);

				functions::transform::TransformStrict<X>::exec(nd4j::transform::Log, result, zShapeBuffer, result, zShapeBuffer, extraParams, tadShapeInfo, tadOffsets);


				delete[] maxResultShapeBuffer;
				delete[] maxResult;
			}
			else if (shape::isVector(xShapeBuffer, 2)) {
				auto max = -nd4j::DataTypeUtils::max<X>();
				X sum = 0;

				auto elementWiseStride = shape::elementWiseStride(xShapeBuffer);
                auto length = shape::length(xShapeBuffer);
				if (elementWiseStride == 1) {

					for (int i = 0; i < length; i++) {
						max = nd4j::math::nd4j_max<X>(max, dx[i]);
					}


					for (int i = 0; i < length; i++) {
						result[i] = nd4j::math::nd4j_exp<X, X>(dx[i] - max);
						sum += result[i];
					}

                    PRAGMA_OMP_SIMD
					for (int i = 0; i < length; i++) {
						result[i] /= sum;
						result[i] = nd4j::math::nd4j_log<X, X>(result[i]);
					}
				}
				else if (elementWiseStride > 1) {
					for (int i = 0; i < length; i++) {
						max = nd4j::math::nd4j_max<X>(max, dx[i * elementWiseStride]);
					}

					for (int i = 0; i < length; i++) {
						result[i * elementWiseStride] = nd4j::math::nd4j_exp<X, X>(dx[i * elementWiseStride] - max);
						sum += result[i * elementWiseStride];
					}

					for (int i = 0; i < length; i++) {
						result[i * elementWiseStride] /= sum;
						result[i * elementWiseStride] = nd4j::math::nd4j_log<X, X>(result[i * elementWiseStride]);
					}
				}
			}
		}

		op_def static X op(X d1, X *params) {
			return d1;
		}
	};


	/**
	* derivative of softmax: s * (1 - s), where s = softmax(x)
	*/
	template<typename X>
	class SoftMaxDerivative {
	public:
		static const bool requiresSpecial = true;

#ifdef __CUDACC__
		/**
		* CUDA softmax derivative: computes softmax via the legacy in-place kernels, then z = s * (1 - s)
		*/

		static inline __device__ void execSpecialCuda(
			                 void *vx, Nd4jLong *xShapeBuffer,
			                 void *vresult, Nd4jLong *zShapeBuffer,
			                 void *vextraParams,
			                 int *allocationPointer, void *reductionPointer, 
                             Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

            auto dx = reinterpret_cast<X *>(vx);
            auto result = reinterpret_cast<X *>(vresult);
            auto extraParams = reinterpret_cast<X *>(vextraParams);

			auto shape = shape::shapeOf(xShapeBuffer);
			__shared__ X maxResult;
			__shared__ Nd4jLong *maxResultShapeBuffer;
			__shared__ Nd4jLong resultEWS;

			auto length = shape::length(xShapeBuffer);

			if (threadIdx.x == 0) {
				resultEWS = shape::elementWiseStride(zShapeBuffer);

				maxResult = (X) 0.0;
			}
			__syncthreads();

			auto stride = shape::stride(xShapeBuffer);
			Nd4jLong maxShape[2] = { shape[0], 1 };

			__shared__ Nd4jLong tempBuffer[8];

			if (threadIdx.x == 0)
                maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT<X>(), maxShape, tempBuffer);
			__syncthreads();

			functions::reduce::ReduceSameInplace<X>::execScalarCudaLegacy(nd4j::reduce::Max, dx, xShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr);
			__syncthreads();

			//subtract max of each row
			functions::scalar::ScalarInplace<X,X,X>::transformCudaLegacy(nd4j::scalar::Subtract, &maxResult, dx, xShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer);
			__syncthreads();

			//after subtracting the row wise maxes take the exp
			functions::transform::TransformStrictInplace<X>::transformCudaLegacy(nd4j::transform::Exp, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets);
			__syncthreads();

			//take the sum for the exponential
			functions::reduce::ReduceSameInplace<X>::execScalarCudaLegacy(nd4j::reduce::Sum, result, zShapeBuffer, extraParams, &maxResult, maxResultShapeBuffer, reductionPointer, nullptr);
			__syncthreads();

			//divide by the sum
			functions::scalar::ScalarInplace<X,X,X>::transformCudaLegacy(nd4j::scalar::Divide, &maxResult, result, zShapeBuffer, extraParams, result, zShapeBuffer, allocationPointer);
			__syncthreads();

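			// derivative step: z = s * (1 - s), where s is the softmax value just computed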
			if (resultEWS >= 1) {
				for (int i = threadIdx.x; i < length; i += blockDim.x) {
					result[i * resultEWS] = result[i * resultEWS] * ((X) 1.0 - result[i * resultEWS]);
				}
			}
			else {
				printf("Non element wise stride not supported right now\n");
			}

		}
#endif


		static void execSpecial(
			void *vx,
			Nd4jLong *xShapeBuffer,
			void *vresult,
			Nd4jLong *zShapeBuffer,
			void *vextraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

            auto dx = reinterpret_cast<X *>(vx);
            auto result = reinterpret_cast<X *>(vresult);
            auto extraParams = reinterpret_cast<X *>(vextraParams);
            
			if (shape::isMatrix(xShapeBuffer, 2)) {
				auto shape = shape::shapeOf(xShapeBuffer);

				auto resultEleStide = shape::elementWiseStride(zShapeBuffer);

				//iterate along rows
				int dimension[1] = { 0 };
				int maxDimension[1] = { 1 };
				auto len = shape::length(xShapeBuffer);
				//compute the row wise maxes
				auto maxResult = new X[shape[0]];

                PRAGMA_OMP_SIMD
				for (int i = 0; i < shape[0]; i++)
					maxResult[i] = 0.0f;

				Nd4jLong maxShape[2] = { shape[0], 1 };
                auto maxResultShapeBuffer = shape::shapeBuffer(2, nd4j::DataTypeUtils::fromT<X>(), maxShape);
				functions::reduce::ReduceSameFunction<X>::exec(nd4j::reduce::Max, dx, xShapeBuffer, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr);

				//subtract max of each row
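				// NOTE: from here on the input is read from result (z), not dx; this legacy path assumes z
				// aliases or was pre-seeded with x (in-place execution contract)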
				functions::broadcast::Broadcast<X,X,X>::exec(nd4j::broadcast::Subtract, result, zShapeBuffer, maxResult, maxResultShapeBuffer, result, zShapeBuffer, dimension, 1, nullptr, nullptr, nullptr, nullptr);

				//after subtracting the row wise maxes take the exp
				functions::transform::TransformStrict<X>::exec(nd4j::transform::Exp, result, zShapeBuffer, result, zShapeBuffer, extraParams, tadShapeInfo, tadOffsets);

				//take the sum for the exponential
				functions::reduce::ReduceSameFunction<X>::exec(nd4j::reduce::Sum, result, zShapeBuffer, extraParams, maxResult, maxResultShapeBuffer, maxDimension, 1, nullptr, nullptr);

				//divide by the sum
				functions::broadcast::Broadcast<X,X,X>::exec(nd4j::broadcast::Divide, result, zShapeBuffer, maxResult, maxResultShapeBuffer, result, zShapeBuffer, dimension, 1, nullptr, nullptr, nullptr, nullptr);

				if (resultEleStide >= 1) {
					if (resultEleStide == 1) {
                        PRAGMA_OMP_SIMD
						for (int i = 0; i < len; i++) {
							result[i] = result[i] * (static_cast<X>(1.0f) - result[i]);
						}

					}
					else {
                        PRAGMA_OMP_SIMD
						for (int i = 0; i < len; i++) {
							result[i * resultEleStide] = result[i * resultEleStide] * (static_cast<X>(1.0f) - result[i * resultEleStide]);
						}

					}
				}
				else {
                    
                    for (int i = 0; i < len; i++) {                        
                        Nd4jLong zOffset = shape::getIndexOffset(i, zShapeBuffer, len);
                        result[zOffset] = result[zOffset] * ((X) 1.0f - result[zOffset]);
                    }
                }


				delete[] maxResultShapeBuffer;
				delete[] maxResult;
			}
			else if (shape::isVector(xShapeBuffer, 2)) {
				auto max = -nd4j::DataTypeUtils::max<X>();
				X sum = 0;

				auto elementWiseStride = shape::elementWiseStride(xShapeBuffer);
				auto length = shape::length(xShapeBuffer);
				if (elementWiseStride == 1) {
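					// same legacy in-place contract: input values are read back from result (z)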

					for (int i = 0; i < length; i++) {
						max = nd4j::math::nd4j_max<X>(max, result[i]);
					}

					for (int i = 0; i < length; i++) {
						result[i] -= max;
						result[i] = nd4j::math::nd4j_exp<X, X>(result[i]);
						sum += result[i];
					}

					for (int i = 0; i < length; i++) {
						result[i] /= sum;
					}

                    for (int i = 0; i < length; i++) {
                        result[i] = result[i] * ((X) 1.0f - result[i]);
                    }
                } else if (elementWiseStride > 1) {

					for (int i = 0; i < length; i++) {
						max = nd4j::math::nd4j_max<X>(max, result[i * elementWiseStride]);
					}

					for (int i = 0; i < length; i++) {
						result[i * elementWiseStride] -= max;
						result[i * elementWiseStride] = nd4j::math::nd4j_exp<X, X>(result[i * elementWiseStride]);
						sum += result[i * elementWiseStride];
					}

                    PRAGMA_OMP_SIMD
					for (int i = 0; i < length; i++) {
						result[i * elementWiseStride] /= sum;
					}

                    PRAGMA_OMP_SIMD
					for (int i = 0; i < length; i++) {
						result[i * elementWiseStride] = result[i * elementWiseStride] * ((X) 1.0f - result[i * elementWiseStride]);
					}
				} else {
                    printf("non-ews access on row not implemented yet");
                }
			}
		}

		op_def static X op(X d1, X *params) {
			return d1;
		}
	};


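	/**
	* isMax: writes 1 at the position of the maximum element and 0 everywhere else,
	* either over the whole array or along the dimensions passed via extraParams
	*/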
	template<typename X, typename Z>
	class IsMax {
	public:
		static const bool requiresSpecial = true;


#ifdef __CUDACC__

		static inline  __device__ void doAllCuda(
			void *vx,
			Nd4jLong *xShapeBuffer,
			void *vresult,
			Nd4jLong *zShapeBuffer,
			void *vextraParams,
			int *allocationPointer, void *reductionPointer) {

            auto dx = reinterpret_cast<X *>(vx);
            auto result = reinterpret_cast<Z *>(vresult);
            auto extraParams = reinterpret_cast<X *>(vextraParams);

			// no-op: a commented-out IndexMax-based implementation lived here; it was never used and has been removed
		}
#endif

#ifdef __CUDACC__
		inline __host__

#elif defined(__GNUC__)


#endif
		static void doAll(
			void *vx,
			Nd4jLong *xShapeBuffer,
            void *vresult,
			Nd4jLong *zShapeBuffer,
			void *vextraParams) {

            auto dx = reinterpret_cast<X *>(vx);
            auto result = reinterpret_cast<Z *>(vresult);
            auto extraParams = reinterpret_cast<X *>(vextraParams);

			auto length = shape::length(xShapeBuffer);
			auto eleStride = shape::elementWiseStride(xShapeBuffer);
			auto resultEleStride = shape::elementWiseStride(zShapeBuffer);
			auto xOrder = shape::order(xShapeBuffer);
			auto resultOrder = shape::order(zShapeBuffer);

			if (xOrder == resultOrder && xOrder == 'c') {
				if (eleStride == 1 && resultEleStride == 1) {
					if (length < ELEMENT_THRESHOLD) {
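						// small input: a single-threaded scan beats the OMP fork/join overhead here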
						int maxIdx = 0;
                        auto currMax = dx[0];

						for (int i = 0; i < length; i++) {
							if (currMax < dx[i]) {
								currMax = dx[i];
								maxIdx = i;
							}

							result[i] = static_cast<Z>(0);

						}

						result[maxIdx] = static_cast<Z>(1);

					}
					else {
						int maxIdx = 0;
						auto currMax = dx[0];
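						// each thread keeps a thread-local max over its scan; locals merge in the critical section below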


PRAGMA_OMP_PARALLEL
{
						int maxIdxLocal = maxIdx;
						auto currMaxLocal = currMax;

						for (int i = 0; i < length; i++) {
							if (currMaxLocal < dx[i]) {
								currMaxLocal = dx[i];
								maxIdxLocal = i;
							}
							result[i] = static_cast<Z>(0);
						}

PRAGMA_OMP_CRITICAL
{
						if (currMax < currMaxLocal) {
							currMax = currMaxLocal;
							maxIdx = maxIdxLocal;
						}
}
}
						result[maxIdx] = static_cast<Z>(1);
					}

				}
				else {
					if (length < ELEMENT_THRESHOLD) {
						int maxIdx = 0;
                        auto currMax = dx[0];

						for (int i = 0; i < length; i++) {
							result[i * resultEleStride] = static_cast<Z>(0);
							if (currMax < dx[i * eleStride]) {
								currMax = dx[i * eleStride];
								maxIdx = i;
							}
						}

						result[maxIdx * resultEleStride] = static_cast<Z>(1);

					}
					else {
						int maxIdx = 0;
						auto currMax = dx[0];


PRAGMA_OMP_PARALLEL
{
						int maxIdxLocal = maxIdx;
						auto currMaxLocal = currMax;

						for (int i = 0; i < length; i++) {
							result[i * resultEleStride] = static_cast<Z>(0);
							if (currMaxLocal < dx[i * eleStride]) {
								currMaxLocal = dx[i * eleStride];
								maxIdxLocal = i;
							}
						}

PRAGMA_OMP_CRITICAL
{
						if (currMax < currMaxLocal) {
							currMax = currMaxLocal;
							maxIdx = maxIdxLocal;
						}
}
}
						result[maxIdx * resultEleStride] = static_cast<Z>(1);
					}

				}
			}


			else {
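				// mismatched orders or unusable element-wise strides: fall back to coordinate-wise raw iteration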
				Nd4jLong shapeIter[MAX_RANK];
				Nd4jLong coord[MAX_RANK];
				int dim;
				Nd4jLong xStridesIter[MAX_RANK];
				Nd4jLong resultStridesIter[MAX_RANK];
				auto xShape = shape::shapeOf(xShapeBuffer);
				auto xStride = shape::stride(xShapeBuffer);
				auto resultStride = shape::stride(zShapeBuffer);
				auto rank = shape::rank(xShapeBuffer);
				auto originalResult = result;
				if (PrepareTwoRawArrayIter<X, Z>(rank,
					xShape,
					dx,
					xStride,
					result,
					resultStride,
					&rank,
					shapeIter,
					&dx,
					xStridesIter,
					&result,
					resultStridesIter) >= 0) {
					auto value = dx[0];
					int idx = 0;
					int maxIdx = 0;
					ND4J_RAW_ITER_START(dim, rank, coord, shapeIter); {
						if (dx[0] > value) {
							value = dx[0];
							maxIdx = idx;
						}

						idx++;
						result[0] = static_cast<Z>(0);

					}
					ND4J_RAW_ITER_TWO_NEXT(
						dim,
						rank,
						coord,
						shapeIter,
						dx,
						xStridesIter,
						result,
						resultStridesIter);

					//pointer to where max value would be
					if (shape::order(zShapeBuffer) == 'c' || (shape::order(zShapeBuffer) == 'f' &&
						maxIdx * shape::stride(zShapeBuffer)[shape::rank(zShapeBuffer) - 1] >=
						shape::length(zShapeBuffer)))
						originalResult[maxIdx] = static_cast<Z>(1);
					else
						originalResult[maxIdx * shape::stride(zShapeBuffer)[shape::rank(zShapeBuffer) - 1]] = static_cast<Z>(1);
				}
			}


		}
	public:


#ifdef __CUDACC__
		/**
		* CUDA isMax: only the whole-array (no explicit dimension) case is dispatched here
		*/

		static inline __device__ void execSpecialCuda(
			             void *vx, Nd4jLong *xShapeBuffer,
			             void *vresult, Nd4jLong *zShapeBuffer,
			             void *vextraParams, int *allocationPointer, 
                         void *reductionPointer, 
                         Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

            auto dx = reinterpret_cast<X *>(vx);
            auto result = reinterpret_cast<Z *>(vresult);
            auto extraParams = reinterpret_cast<X *>(vextraParams);

			// FIXME: MAX_DIMENSION is lower than the FP16 frame
			if (extraParams == nullptr || (int) extraParams[0] == MAX_DIMENSION) {
				doAllCuda(dx, xShapeBuffer, result, zShapeBuffer, extraParams, allocationPointer, reductionPointer);
			}
		}
#endif

		static void execSpecial(
			void *vx,
			Nd4jLong *xShapeBuffer,
			void *vresult,
			Nd4jLong *zShapeBuffer,
			void *vextraParams,
			Nd4jLong *tadShapeInfo,
			Nd4jLong *tadOffsets) {

            auto dx = reinterpret_cast<X *>(vx);
            auto result = reinterpret_cast<Z *>(vresult);
            auto extraParams = reinterpret_cast<X *>(vextraParams);

			//FIXME: this op should be moved to CustomOps
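			// extraParams layout: [dimensionLength, dim0, dim1, ...]; no params, or MAX_DIMENSION, selects whole-array isMax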
			if (extraParams == nullptr || (int)extraParams[0] == 0 ||
				((int)extraParams[0] == 1 && (int)extraParams[1] == MAX_DIMENSION)) {
				doAll(dx, xShapeBuffer, result, zShapeBuffer, extraParams);
			}
			else if (shape::isVector(xShapeBuffer)) {
				auto dimensionLength = (int)extraParams[0];
				auto dimension = new int[dimensionLength];
				auto length = shape::length(xShapeBuffer);
				for (int i = 0; i < dimensionLength; i++) {
					dimension[i] = (int)extraParams[i + 1];
				}
				if (shape::shapeOf(xShapeBuffer)[dimension[0]] == 1) {
					for (int i = 0; i < length; i++) {
						result[i] = static_cast<Z>(1);
					}
				}
				else {
					auto eleStride = shape::elementWiseStride(xShapeBuffer);
					if (eleStride == 1) {
						int maxIdx = 0;
						auto currMax = dx[0];
						if (length < ELEMENT_THRESHOLD) {

							for (int i = 0; i < length; i++) {
								if (currMax < dx[i]) {
									currMax = dx[i];
									maxIdx = i;
								}

								result[i] = static_cast<Z>(0);

							}
						}
						else {
PRAGMA_OMP_PARALLEL
{
							int maxIdxLocal = maxIdx;
							auto currMaxLocal = currMax;

							for (int i = 0; i < length; i++) {
								if (currMaxLocal < dx[i]) {
									currMaxLocal = dx[i];
									maxIdxLocal = i;
								}

								result[i] = static_cast<Z>(0);

							}

							PRAGMA_OMP_CRITICAL
                            {
							    if (currMax < currMaxLocal) {
								    currMax = currMaxLocal;
								    maxIdx = maxIdxLocal;
							    }
                            }
}
						}

						result[maxIdx] = static_cast<Z>(1);

					}


					else {
						int maxIdx = 0;
						auto currMax = dx[0];
						if (length < ELEMENT_THRESHOLD) {

							for (int i = 0; i < length; i++) {
								if (currMax < dx[i * eleStride]) {
									currMax = dx[i * eleStride];
									maxIdx = i;
								}

								result[i] = static_cast<Z>(0);
							}
						}
						else {

PRAGMA_OMP_PARALLEL
{
							int maxIdxLocal = maxIdx;
							auto currMaxLocal = currMax;

							for (int i = 0; i < length; i++) {
								if (currMaxLocal < dx[i * eleStride]) {
									currMaxLocal = dx[i * eleStride];
									maxIdxLocal = i;
								}

								result[i] = static_cast<Z>(0);
							}

PRAGMA_OMP_CRITICAL
{
							if (currMax < currMaxLocal) {
								currMax = currMaxLocal;
								maxIdx = maxIdxLocal;
							}
}
}
						}

						result[maxIdx] = static_cast<Z>(1);
					}
				}

				delete[] dimension;
			}
			else {
                auto dimensionLength = (int) extraParams[0];
                auto dimension = new int[dimensionLength];

                PRAGMA_OMP_SIMD
                for (int i = 0; i < dimensionLength; i++) {
                    dimension[i] = (int) extraParams[i + 1];
                }
                //decompose in to several sub tads after
                //moving all dimensions (in sorted order)
                //to the back.
                //permuted version of the x shape info for setting up the tad problem				
				auto tadShapeShapeInfo = tadShapeInfo;
				if (tadShapeInfo == nullptr) {
                    auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeBuffer, dimension, dimensionLength);

					tadShapeShapeInfo = tadPack.primaryShapeInfo();
					tadOffsets = tadPack.primaryOffsets();
                    tadShapeInfo = tadShapeShapeInfo;
				}

                auto tadLength = shape::length(tadShapeInfo);//shape::tadLength(xShapeBuffer, dimension, dimensionLength);
                auto tads = shape::length(xShapeBuffer) / tadLength;

                int tadsPerThread = tads / TAD_THRESHOLD;
                int num_threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
                num_threads = nd4j::math::nd4j_min<int>(num_threads, omp_get_max_threads());

                auto tadEWS = shape::elementWiseStride(tadShapeShapeInfo);
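                // assumes z's TADs share x's layout, so x's TAD element-wise stride is reused for z below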
                auto zEWS = tadEWS;

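                // contiguous chunk of TADs per thread; the +8 pad over-provisions, and end is clamped to tads below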
                int span = (tads / num_threads) + 8;

                PRAGMA_OMP_PARALLEL_THREADS(num_threads)
                {
                    int tid = omp_get_thread_num();
                    int start = span * tid;
                    int end = span * (tid + 1);
                    if (end > tads) end = tads;

                    for (int r = start; r < end; r++) {
                        if (tadEWS > 0 && zEWS > 0 && dimensionLength == 1) {
                            auto rX = dx + tadOffsets[r];
                            auto rZ = result + tadOffsets[r];

                            auto maxValue = rX[0];
                            int maxIdx = 0;
                            if (tadEWS == 1 && zEWS == 1) {

                                for (int i = 0; i < tadLength; i++) {
                                    if (rX[i] > maxValue) {
                                        maxIdx = i;
                                        maxValue = rX[i];
                                    }
                                }


                                for (int i = 0; i < tadLength; i++) {
                                    rZ[i] = static_cast<Z>(maxIdx == i);
                                }

                            } else {

                                for (int i = 0; i < tadLength; i++) {
                                    if (rX[i * tadEWS] > maxValue) {
                                        maxIdx = i;
                                        maxValue = rX[i * tadEWS];
                                    }
                                }

                                for (int i = 0; i < tadLength; i++) {
                                    rZ[i * zEWS] = static_cast<Z>(maxIdx == i);
                                }
                            }
                        } else {

                            auto offset = tadOffsets[r];
                            Nd4jLong shapeIter[MAX_RANK];
                            Nd4jLong coord[MAX_RANK];
                            int dim;
                            Nd4jLong xStridesIter[MAX_RANK];
                            Nd4jLong resultStridesIter[MAX_RANK];
                            auto xShape = shape::shapeOf(tadShapeShapeInfo);
                            auto xStride = shape::stride(tadShapeShapeInfo);
                            auto resultStride = shape::stride(tadShapeShapeInfo);
                            int rank = shape::rank(tadShapeShapeInfo);
                            auto xPointer = dx + offset;
                            auto resultPointer = result + offset;
                            auto maxValue = xPointer[0];

                            auto maxCursor = resultPointer;
                            Nd4jPointer maxCursorLong = reinterpret_cast<Nd4jPointer>(maxCursor);
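                            // raw-iteration fallback: walk x and z coordinate-wise, zeroing z and tracking where the max lives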
                            if (PrepareTwoRawArrayIter<X,Z>(rank,
                                                             xShape,
                                                             xPointer,
                                                             xStride,
                                                             resultPointer,
                                                             resultStride,
                                                             &rank,
                                                             shapeIter,
                                                             &xPointer,
                                                             xStridesIter,
                                                             &resultPointer,
                                                             resultStridesIter) >= 0) {
                                   ND4J_RAW_ITER_START(dim, rank, coord, shapeIter); {
                                       if (maxValue < xPointer[0]) {
                                           maxCursor = resultPointer;
                                           maxCursorLong = reinterpret_cast<Nd4jPointer>(resultPointer);
                                           maxValue = xPointer[0];
                                       }

                                       resultPointer[0] = static_cast<Z>(0);
                                   }
                                   ND4J_RAW_ITER_TWO_NEXT(dim,
                                                          rank,
                                                          coord,
                                                          shapeIter,
                                                          xPointer,
                                                          xStridesIter,
                                                          resultPointer,
                                                          resultStridesIter);
                                   maxCursor = reinterpret_cast<Z *>(maxCursorLong);
                                   maxCursor[0] = static_cast<Z>(1);
                            }
                        }
                    }
                }

                delete[] dimension;
            }
		}

		// never invoked: requiresSpecial routes execution through execSpecial; return the input unchanged
		op_def static Z op(X d1, X *params) {
			return static_cast<Z>(d1);
		}
	};
}