/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#ifndef LIBND4J_SPECIAL_RANDOM_OPS_H
#define LIBND4J_SPECIAL_RANDOM_OPS_H

// NOTE: the original include targets were lost; the five below are reconstructed
// from the symbols used in this file (the method_* macros, shape::*,
// nd4j::graph::RandomGenerator, the CUDA specials, samediff::Threads) and the
// exact header paths may differ per libnd4j version.
#include <ops/random_ops.h>
#include <helpers/shape.h>
#include <graph/RandomGenerator.h>
#include <specials_cuda.h>
#include <execution/Threads.h>

namespace randomOps {

//////////////////////////////////////////////////////////////////////
    template<typename T>
    class Choice {
    public:

        method_idx
        method_X
        method_XY

        static const bool requiresSpecial = true;

#ifdef __CUDACC__
        __device__ static inline void specialOpCuda(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) {
            /**
             * X holds data,
             * Y holds probabilities
             * Z will hold results
             */

            // TODO: we might want to skip this sum and require the probabilities array to hold real probabilities, i.e. to sum to 1.0
            //T probSum = extraArguments[0];

            __shared__ Nd4jLong xLength;
            __shared__ Nd4jLong yLength;
            __shared__ Nd4jLong zLength;

            __shared__ Nd4jLong xEWS;
            __shared__ Nd4jLong yEWS;
            __shared__ Nd4jLong zEWS;

            __shared__ char xOrder;
            __shared__ char yOrder;
            __shared__ char zOrder;

            __shared__ nd4j::graph::RandomGenerator *rng;
            __shared__ unsigned char *cB;
            __shared__ unsigned char *dB;
            __shared__ nd4j::graph::RandomGenerator *devRng;

            if (threadIdx.x == 0) {
                extern __shared__ unsigned char shmem[];
                rng = (nd4j::graph::RandomGenerator*) shmem;
                cB = shmem;
                devRng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
                dB = reinterpret_cast<unsigned char *>(state);

                xLength = shape::length(xShapeBuffer);
                yLength = shape::length(yShapeBuffer);
                zLength = shape::length(zShapeBuffer);

                xEWS = shape::elementWiseStride(xShapeBuffer);
                yEWS = shape::elementWiseStride(yShapeBuffer);
                zEWS = shape::elementWiseStride(zShapeBuffer);

                xOrder = shape::order(xShapeBuffer);
                yOrder = shape::order(yShapeBuffer);
                zOrder = shape::order(zShapeBuffer);
            }
            __syncthreads();

            // using this loop instead of memcpy
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x)
                cB[e] = dB[e];

            __syncthreads();

            int tid = blockIdx.x * blockDim.x + threadIdx.x;

            if (zEWS >= 1 && xEWS >= 1 && yEWS >= 1 && xOrder == yOrder && xOrder == zOrder) {
                for (Nd4jLong e = tid; e < zLength; e += blockDim.x * gridDim.x) {
                    T prob = rng->relativeT<T>(e);
                    T cumProb = (T) 0.0f;
                    for (Nd4jLong f = 0; f < yLength; f++) {
                        T relProb = y[f * yEWS];
                        cumProb += relProb;

                        if (prob <= cumProb || f == yLength - 1) {
                            z[e * zEWS] = x[f * xEWS];
                            // advance f past yLength to end the scan without a break
                            f += yLength;
                        }
                        // __syncthreads(); // eliminated due to RTX 20xx-specific behavior
                    }
                    // __syncthreads(); // eliminated due to RTX 20xx-specific behavior
                }
            } else {
                for (Nd4jLong i = tid; i < zLength; i += blockDim.x * gridDim.x) {
                    auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer);
                    T prob = rng->relativeT<T>(i);
                    T cumProb = (T) 0.0f;
                    for (Nd4jLong f = 0; f < yLength; f++) {
                        auto yOffset2 = shape::getIndexOffset(f, yShapeBuffer);
                        T relProb = y[yOffset2];
                        cumProb += relProb;

                        if (prob <= cumProb || f == yLength - 1) {
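                            // inverse-CDF pick: the first bucket f whose running sum reaches the
                            // uniform draw wins; the f == yLength - 1 clause guarantees a pick
                            // even if the probabilities sum to slightly less than 1.0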
                            auto xOffset2 = shape::getIndexOffset(f, xShapeBuffer);
                            z[zOffset2] = x[xOffset2];
                            // advance f past yLength to end the scan without a break
                            f += yLength;
                        }
                        // __syncthreads(); // eliminated due to RTX 20xx-specific behavior
                    }
                    // __syncthreads(); // eliminated due to RTX 20xx-specific behavior
                }
            }
        }
#endif

        static inline void specialOp(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) {
            /**
             * X holds data,
             * Y holds probabilities
             * Z will hold results
             */

            //nd4j::random::RandomBuffer *buffer = reinterpret_cast<nd4j::random::RandomBuffer *>(state);
            nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);

            // TODO: we might want to skip this sum and require the probabilities array to hold real probabilities, i.e. to sum to 1.0
            //T probSum = extraArguments[0];

            auto xLength = shape::length(xShapeBuffer);
            auto yLength = shape::length(yShapeBuffer);
            auto zLength = shape::length(zShapeBuffer);

            auto xEWS = shape::elementWiseStride(xShapeBuffer);
            auto yEWS = shape::elementWiseStride(yShapeBuffer);
            auto zEWS = shape::elementWiseStride(zShapeBuffer);

            int elementsPerThread = zLength / TAD_THRESHOLD;
            int _threads = nd4j::math::nd4j_max<int>(1, elementsPerThread);
            _threads = nd4j::math::nd4j_min<int>(_threads, nd4j::Environment::getInstance()->maxThreads());

            if (zEWS >= 1 && xEWS >= 1 && yEWS >= 1) {
                auto func = PRAGMA_THREADS_FOR {
                    for (uint64_t e = start; e < stop; e += increment) {
                        T prob = rng->relativeT<T>(e);
                        T cumProb = (T) 0.0f;
                        for (Nd4jLong f = 0; f < yLength; f++) {
                            T relProb = y[f * yEWS];
                            cumProb += relProb;

                            if (prob <= cumProb || f == yLength - 1) {
                                z[e * zEWS] = x[f * xEWS];
                                break;
                            }
                        }
                    }
                };

                samediff::Threads::parallel_for(func, 0, zLength, 1, _threads);
            } else {
                auto func = PRAGMA_THREADS_FOR {
                    // each thread handles only its own [start, stop) chunk, matching the
                    // strided branch above; results are unchanged since the generator is
                    // counter-based and keyed on the absolute index i
                    for (auto i = start; i < stop; i += increment) {
                        auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer);
                        T prob = rng->relativeT<T>(i);
                        T cumProb = (T) 0.0f;
                        for (Nd4jLong f = 0; f < yLength; f++) {
                            auto yOffset2 = shape::getIndexOffset(f, yShapeBuffer);
                            T relProb = y[yOffset2];
                            cumProb += relProb;

                            if (prob <= cumProb || f == yLength - 1) {
                                auto xOffset2 = shape::getIndexOffset(f, xShapeBuffer);
                                z[zOffset2] = x[xOffset2];
                                break;
                            }
                        }
                    }
                };

                samediff::Threads::parallel_for(func, 0, zLength, 1, _threads);
            }
        }
    };

//////////////////////////////////////////////////////////////////////
    /**
     * This Op produces random values within specified boundaries.
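     * Distribution is Gaussian: each output pair is produced from two uniform
     * draws (r0, r1) via the Box-Muller transform,
     *
     *   z0 = sqrt(-2 * ln(r0)) * cos(2 * pi * r1)
     *   z1 = sqrt(-2 * ln(r0)) * sin(2 * pi * r1)
     *
     * which yields two independent N(0, 1) samples per pair; these are scaled by
     * stddev and shifted by mean. That is why the loops below only run to `middle`
     * and fill indices e and e + middle in a single iteration.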
     */
    template<typename T>
    class GaussianDistribution {
    public:

        method_XY
        method_X
        method_idx

        static const bool requiresSpecial = true;

#ifdef __CUDACC__
        __device__ static inline void specialOpCuda(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) {
            __shared__ T epsilon;
            __shared__ T two_pi;

            __shared__ Nd4jLong zLength;
            __shared__ Nd4jLong zEWS;
            __shared__ Nd4jLong yEWS;

            __shared__ T mean;
            __shared__ T stddev;
            __shared__ int step;

            __shared__ T *tZ;

            __shared__ nd4j::graph::RandomGenerator* rng;
            __shared__ unsigned char *cB;
            __shared__ unsigned char *dB;
            __shared__ nd4j::graph::RandomGenerator *devRng;

            if (threadIdx.x == 0) {
                extern __shared__ unsigned char shmem[];
                rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(shmem);
                cB = shmem;
                devRng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
                dB = reinterpret_cast<unsigned char *>(state);

                tZ = reinterpret_cast<T*>(shmem + sizeof(nd4j::graph::RandomGenerator));

                zLength = shape::length(zShapeBuffer);
                zEWS = shape::elementWiseStride(zShapeBuffer);
                yEWS = shape::elementWiseStride(yShapeBuffer);

                epsilon = static_cast<T>(1e-5);
                two_pi = static_cast<T>(2.0f) * static_cast<T>(3.14159265358979323846);

                mean = extraArguments[0];
                stddev = extraArguments[1];

                step = (blockDim.x * gridDim.x);
            }
            __syncthreads();

            // using this loop instead of memcpy
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x)
                cB[e] = dB[e];

            __syncthreads();

            int tid = blockIdx.x * blockDim.x + threadIdx.x;
            int middle = zLength % 2 == 0 ? zLength / 2 : zLength / 2 + 1;

            T t(-2.0f);

            for (int e = tid; e < middle; e += step) {
                auto epm = e + middle;

                // we need to get random values
                T r0 = rng->relativeT<T>(e, epsilon, static_cast<T>(1.0f));
                T r1 = rng->relativeT<T>(epm, epsilon, static_cast<T>(1.0f));

                T realMean0 = y == z ? mean : y[e * yEWS];

                z[e * zEWS] = (nd4j::math::nd4j_sqrt<T,T>(t * nd4j::math::nd4j_log<T,T>(r0)) * nd4j::math::nd4j_cos<T,T>(two_pi * r1)) * stddev + realMean0;

                if (epm < zLength) {
                    T realMean1 = y == z ? mean : y[epm * yEWS];
                    z[epm * zEWS] = (nd4j::math::nd4j_sqrt<T,T>(t * nd4j::math::nd4j_log<T,T>(r0)) * nd4j::math::nd4j_sin<T,T>(two_pi * r1)) * stddev + realMean1;
                }
            }
        }
#endif

        static inline void specialOp(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) {
            const T two_pi = static_cast<T>(2.0f) * static_cast<T>(3.14159265358979323846);

            auto zLength = shape::length(zShapeBuffer);
            auto yEWS = shape::elementWiseStride(yShapeBuffer);
            auto zEWS = shape::elementWiseStride(zShapeBuffer);

            auto middle = zLength % 2 + zLength / 2;

            int elementsPerThread = middle / TAD_THRESHOLD;
            int _threads = nd4j::math::nd4j_max<int>(1, elementsPerThread);
            _threads = nd4j::math::nd4j_min<int>(_threads, nd4j::Environment::getInstance()->maxThreads());

            int span = (middle / _threads) + 8;

            // we're enforcing even chunks, since it's mandatory for this algorithm
            span -= span % 2;

            //nd4j::random::RandomBuffer *buffer = reinterpret_cast<nd4j::random::RandomBuffer *>(state);
            nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);

            const T mean = extraArguments[0];
            const T stddev = extraArguments[1];

            const T epsilon = static_cast<T>(1e-5);

            auto func = PRAGMA_THREADS_FOR {
                for (uint64_t e = start; e < stop; e += increment) {
                    auto epm = e + middle;

                    // we need to get random values
                    T r0 = rng->relativeT<T>(e, epsilon, static_cast<T>(1.0f));
                    T r1 = rng->relativeT<T>(epm, epsilon, static_cast<T>(1.0f));

                    T realMean0 = y == z ? mean : y[e * yEWS];

                    auto z0 = (nd4j::math::nd4j_sqrt<T,T>(static_cast<T>(-2.0f) * nd4j::math::nd4j_log<T,T>(r0)) * nd4j::math::nd4j_cos<T,T>(two_pi * r1)) * stddev + realMean0;
                    z[e * zEWS] = z0;

                    if (epm < zLength) {
                        T realMean1 = y == z ? mean : y[epm * yEWS];
                        auto z1 = (nd4j::math::nd4j_sqrt<T,T>(static_cast<T>(-2.0f) * nd4j::math::nd4j_log<T,T>(r0)) * nd4j::math::nd4j_sin<T,T>(two_pi * r1)) * stddev + realMean1;
                        z[epm * zEWS] = z1;
                    }
                }
            };

            samediff::Threads::parallel_for(func, 0, middle, 1, _threads);
        }
    };

//////////////////////////////////////////////////////////////////////
    /**
     * This Op produces random values within [0..N]. Distribution is binomial.
     */
    template<typename T>
    class BinomialDistribution {
    public:

        method_XY
        method_X
        method_idx

        static const bool requiresSpecial = true;

#ifdef __CUDACC__
        __device__ static inline void specialOpCuda(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) {
            int trials = (int) extraArguments[0];
            T prob = extraArguments[1];

            __shared__ Nd4jLong zLength;
            __shared__ int yEWS;
            __shared__ int zEWS;

            __shared__ nd4j::graph::RandomGenerator* rng;
            __shared__ unsigned char *cB;
            __shared__ unsigned char *dB;
            __shared__ nd4j::graph::RandomGenerator *devRng;

            if (threadIdx.x == 0) {
                extern __shared__ unsigned char shmem[];
                rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(shmem);
                cB = shmem;
                devRng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
                dB = reinterpret_cast<unsigned char *>(state);

                zLength = shape::length(zShapeBuffer);
                yEWS = shape::elementWiseStride(yShapeBuffer);
                zEWS = shape::elementWiseStride(zShapeBuffer);
            }
            __syncthreads();

            // using this loop instead of memcpy
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x)
                cB[e] = dB[e];

            __syncthreads();

            int tid = blockIdx.x * blockDim.x + threadIdx.x;

            for (Nd4jLong e = tid; e < zLength; e += blockDim.x * gridDim.x) {
                int success = 0;
                for (int t = 1; t <= trials; t++) {
                    T randVal = rng->relativeT<T>((e+1) * t);
                    if (y != z) {
                        // we're using external probs: one probability per trial
                        prob = y[(t-1) * yEWS];
                    }

                    if (randVal < prob)
                        success++;
                }

                // if trials is set to 0, effectively we just have a successful memset
                z[e * zEWS] = static_cast<T>(success);
            }
        }
#endif

        static inline void specialOp(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) {
            int trials = (int) extraArguments[0];

            Nd4jLong zLength = shape::length(zShapeBuffer);

            auto yEWS = shape::elementWiseStride(yShapeBuffer);
            auto zEWS = shape::elementWiseStride(zShapeBuffer);

            int elementsPerThread = zLength / TAD_THRESHOLD;
            int _threads = nd4j::math::nd4j_max<int>(1, elementsPerThread);
            _threads = nd4j::math::nd4j_min<int>(_threads, nd4j::Environment::getInstance()->maxThreads());

            T prob = extraArguments[1];

            nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);

            auto func = PRAGMA_THREADS_FOR {
                for (Nd4jLong e = start; e < stop; e += increment) {
                    int success = 0;
                    for (int t = 1; t <= trials; t++) {
                        T randVal = rng->relativeT<T>((e+1) * t);
                        if (y != z) {
                            // we're using external probs: one probability per trial
                            prob = y[(t-1) * yEWS];
                        }

                        if (randVal < prob)
                            success++;
                    }

                    // if trials is set to 0, effectively we just have a successful memset
                    z[e * zEWS] = static_cast<T>(success);
                }
            };

            samediff::Threads::parallel_for(func, 0, zLength, 1, _threads);
        }
    };

//////////////////////////////////////////////////////////////////////
    /**
     * This Op produces random values within [0..N]. Distribution is binomial.
     * Ex variant: the probability is read per output element rather than per trial.
     */
    template<typename T>
    class BinomialDistributionEx {
    public:

        method_XY
        method_X
        method_idx

        static const bool requiresSpecial = true;

#ifdef __CUDACC__
        __device__ static inline void
        specialOpCuda(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) {
            int trials = (int) extraArguments[0];
            T prob = extraArguments[1];

            __shared__ Nd4jLong zLength;
            __shared__ int yEWS;
            __shared__ int zEWS;

            __shared__ nd4j::graph::RandomGenerator* rng;
            __shared__ unsigned char *cB;
            __shared__ unsigned char *dB;
            __shared__ nd4j::graph::RandomGenerator *devRng;

            if (threadIdx.x == 0) {
                extern __shared__ unsigned char shmem[];
                rng = (nd4j::graph::RandomGenerator*) shmem;
                cB = shmem;
                devRng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
                dB = reinterpret_cast<unsigned char *>(state);

                zLength = shape::length(zShapeBuffer);
                yEWS = shape::elementWiseStride(yShapeBuffer);
                zEWS = shape::elementWiseStride(zShapeBuffer);
            }
            __syncthreads();

            // using this loop instead of memcpy
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x)
                cB[e] = dB[e];

            __syncthreads();

            int tid = blockIdx.x * blockDim.x + threadIdx.x;

            for (Nd4jLong e = tid; e < zLength; e += blockDim.x * gridDim.x) {
                int success = 0;
                for (int t = 1; t <= trials; t++) {
                    T randVal = rng->relativeT<T>((e+1) * t);
                    if (y != z) {
                        // we're using external probs: one probability per output element
                        prob = y[e * yEWS];
                    }

                    if (randVal < prob)
                        success++;
                }

                // if trials is set to 0, effectively we just have a successful memset
                z[e * zEWS] = (T) success;
            }
        }
#endif

        static inline void specialOp(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) {
            int trials = (int) extraArguments[0];

            Nd4jLong zLength = shape::length(zShapeBuffer);

            auto yEWS = shape::elementWiseStride(yShapeBuffer);
            auto zEWS = shape::elementWiseStride(zShapeBuffer);

            int elementsPerThread = zLength / TAD_THRESHOLD;
            int _threads = nd4j::math::nd4j_max<int>(1, elementsPerThread);
            _threads = nd4j::math::nd4j_min<int>(_threads, nd4j::Environment::getInstance()->maxThreads());

            T prob = extraArguments[1];

            //nd4j::random::RandomBuffer *buffer = reinterpret_cast<nd4j::random::RandomBuffer *>(state);
            nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);

            auto func = PRAGMA_THREADS_FOR {
                for (uint64_t e = start; e < stop; e += increment) {
                    int success = 0;
                    for (int t = 1; t <= trials; t++) {
                        T randVal = rng->relativeT<T>((e+1) * t);
                        if (y != z) {
                            // we're using external probs: one probability per output element
                            prob = y[e * yEWS];
                        }

                        if (randVal < prob)
                            success++;
                    }

                    // if trials is set to 0, effectively we just have a successful memset
                    z[e * zEWS] = static_cast<T>(success);
                }
            };

            samediff::Threads::parallel_for(func, 0, zLength, 1, _threads);
        }
    };

//////////////////////////////////////////////////////////////////////
    // This Op produces random Gaussian values within [mean - 2*stddev, mean + 2*stddev]
    template<typename T>
    class TruncatedNormalDistribution {
    private:
        static inline _CUDA_HD T step(nd4j::graph::RandomGenerator* rng, T mean, T stddev, Nd4jLong e, Nd4jLong middle, T& z) {
            auto epm = e + middle;
            const T two_pi = static_cast<T>(2.0f) * static_cast<T>(3.14159265358979323846);
            const T epsilon = static_cast<T>(1.e-5f);

            // we need to get random values
            T r0 = rng->relativeT<T>(e, epsilon, static_cast<T>(1.0f));
            T r1 = rng->relativeT<T>(epm, epsilon, static_cast<T>(1.0f));

            T realMean0 = mean;

            auto z0 = (nd4j::math::nd4j_sqrt<T,T>(static_cast<T>(-2.0f) * nd4j::math::nd4j_log<T,T>(r0)) * nd4j::math::nd4j_cos<T,T>(two_pi * r1)) * stddev + realMean0;
            z = z0;

            // note: epm = e + middle >= middle for e >= 0, so this branch never fires;
            // kept to mirror the paired-output structure of GaussianDistribution
            if (epm < middle) {
                T realMean1 = mean;
                auto z1 = (nd4j::math::nd4j_sqrt<T,T>(static_cast<T>(-2.0f) * nd4j::math::nd4j_log<T,T>(r0)) * nd4j::math::nd4j_sin<T,T>(two_pi * r1)) * stddev + realMean1;
                z = z1;
            }
            return z;
        }
    public:

        method_XY
        method_X
        method_idx

        static const bool requiresSpecial = true;
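        // Truncation strategy used below: specialOp/specialOpCuda first fill z with
        // plain Gaussian samples, then any value outside [mean - 2*stddev, mean + 2*stddev]
        // gets exactly one Box-Muller redraw via step(); if the redraw is still out of
        // range, the value is set to mean + DataTypeUtils::min<T>() instead of being
        // resampled again, which keeps both paths free of unbounded rejection loops.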
#ifdef __CUDACC__
        __device__ static inline void specialOpCuda(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) {
            __shared__ T epsilon;
            __shared__ T two_pi;

            __shared__ Nd4jLong zLength;
            __shared__ Nd4jLong zEWS;
            __shared__ Nd4jLong yEWS;

            __shared__ T mean;
            __shared__ T stddev;
            __shared__ int step;

            __shared__ T *tZ;

            __shared__ nd4j::graph::RandomGenerator* rng;
            __shared__ unsigned char *cB;
            __shared__ unsigned char *dB;
            __shared__ nd4j::graph::RandomGenerator* devRng;
            __shared__ Nd4jLong middle;

            if (threadIdx.x == 0) {
                extern __shared__ unsigned char shmem[];
                rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(shmem);
                cB = shmem;
                devRng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
                dB = reinterpret_cast<unsigned char *>(state);

                tZ = reinterpret_cast<T*>(shmem + sizeof(nd4j::graph::RandomGenerator));

                zLength = shape::length(zShapeBuffer);
                zEWS = shape::elementWiseStride(zShapeBuffer);
                yEWS = shape::elementWiseStride(yShapeBuffer);

                epsilon = static_cast<T>(1e-6f);
                two_pi = static_cast<T>(2.0f) * static_cast<T>(3.14159265358979323846);

                mean = extraArguments[0];
                stddev = extraArguments[1];

                step = (blockDim.x * gridDim.x);
                middle = zLength / 2 + (zLength % 2);
            }
            __syncthreads();

            // using this loop instead of memcpy
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x)
                cB[e] = dB[e];

            __syncthreads();

            int tid = blockIdx.x * blockDim.x + threadIdx.x;

            GaussianDistribution<T>::specialOpCuda(state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments);
            __syncthreads();

            T ds = nd4j::math::nd4j_abs<T>(stddev) * static_cast<T>(2.0f);
            for (Nd4jLong e = tid; e < zLength; e += step) {
                if (z[e] > mean + ds || z[e] < mean - ds) {
                    z[e] = TruncatedNormalDistribution<T>::step(rng, mean, stddev, e, middle, z[e]);

                    if (z[e] > mean + ds || z[e] < mean - ds)
                        z[e] = mean + nd4j::DataTypeUtils::min<T>();
                }
            }
        }
#endif

        static inline void specialOp(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) {
            GaussianDistribution<T>::specialOp(state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments);

            Nd4jLong zLength = shape::length(zShapeBuffer);
            //auto yEWS = shape::elementWiseStride(yShapeBuffer);
            //auto zEWS = shape::elementWiseStride(zShapeBuffer);
            nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);

            T mean = extraArguments[0];
            T stddev = extraArguments[1];
            T ds = nd4j::math::nd4j_abs<T>(stddev) * (T) 2.0f;
            Nd4jLong middle = zLength / 2 + (zLength % 2);

            int elementsPerThread = middle / TAD_THRESHOLD;
            int _threads = nd4j::math::nd4j_max<int>(1, elementsPerThread);
            _threads = nd4j::math::nd4j_min<int>(_threads, nd4j::Environment::getInstance()->maxThreads());

            const T epsilon = static_cast<T>(1e-5);

            auto func = PRAGMA_THREADS_FOR {
                for (uint64_t e = start; e < stop; e += increment) {
                    if (z[e] > mean + ds || z[e] < mean - ds) {
                        z[e] = step(rng, mean, stddev, e, middle, z[e]);

                        if (z[e] > mean + ds || z[e] < mean - ds)
                            z[e] = mean + nd4j::DataTypeUtils::min<T>();
                    }
                }
            };

            samediff::Threads::parallel_for(func, 0, zLength, 1, _threads);
        }
    };

//////////////////////////////////////////////////////////////////////
    // This Op produces random values drawn from a log-normal distribution
    template<typename T>
    class LogNormalDistribution {
    public:

        method_XY
        method_X
        method_idx

        static const bool requiresSpecial = true;

#ifdef __CUDACC__
        __device__ static inline void specialOpCuda(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) {
            __shared__ T epsilon;
            __shared__ T two_pi;

            __shared__ Nd4jLong
            zLength;
            __shared__ Nd4jLong zEWS;
            __shared__ Nd4jLong yEWS;

            __shared__ T mean;
            __shared__ T stddev;
            __shared__ int step;

            __shared__ T *tZ;

            __shared__ nd4j::graph::RandomGenerator* rng;
            __shared__ unsigned char *cB;
            __shared__ unsigned char *dB;
            __shared__ nd4j::graph::RandomGenerator* devRng;

            if (threadIdx.x == 0) {
                extern __shared__ unsigned char shmem[];
                // rng points at the shared-memory copy, matching the other ops in this file
                rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(shmem);
                cB = shmem;
                devRng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
                dB = reinterpret_cast<unsigned char *>(state);

                tZ = reinterpret_cast<T*>(shmem + sizeof(nd4j::graph::RandomGenerator));

                zLength = shape::length(zShapeBuffer);
                zEWS = shape::elementWiseStride(zShapeBuffer);
                yEWS = shape::elementWiseStride(yShapeBuffer);

                epsilon = static_cast<T>(1e-5);
                two_pi = static_cast<T>(2.0f) * static_cast<T>(3.14159265358979323846);

                mean = extraArguments[0];
                stddev = extraArguments[1];

                step = (blockDim.x * gridDim.x);
            }
            __syncthreads();

            // using this loop instead of memcpy
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x)
                cB[e] = dB[e];

            __syncthreads();

            int tid = blockIdx.x * blockDim.x + threadIdx.x;
            int middle = zLength % 2 == 0 ? zLength / 2 : zLength / 2 + 1;

            for (Nd4jLong e = tid; e < middle; e += step) {
                auto epm = e + middle;

                // we need to get random values
                T r0 = rng->relativeT<T>(e, epsilon, static_cast<T>(1.0f));
                T r1 = rng->relativeT<T>(epm, epsilon, static_cast<T>(1.0f));

                T realMean = y == z ? mean : y[e * yEWS];

                z[e * zEWS] = nd4j::math::nd4j_exp<T,T>((nd4j::math::nd4j_sqrt<T,T>(static_cast<T>(-2.0f) * nd4j::math::nd4j_log<T,T>(r0)) * nd4j::math::nd4j_cos<T,T>(two_pi * r1)) * stddev + realMean);

                if (epm < zLength) {
                    realMean = y == z ? mean : y[epm * yEWS];
                    z[epm * zEWS] = nd4j::math::nd4j_exp<T,T>((nd4j::math::nd4j_sqrt<T,T>(static_cast<T>(-2.0f) * nd4j::math::nd4j_log<T,T>(r0)) * nd4j::math::nd4j_sin<T,T>(two_pi * r1)) * stddev + realMean);
                }
            }
        }
#endif

        static inline void specialOp(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) {
            const T two_pi = static_cast<T>(2.0f) * static_cast<T>(3.14159265358979323846);

            Nd4jLong zLength = shape::length(zShapeBuffer);
            auto yEWS = shape::elementWiseStride(yShapeBuffer);
            auto zEWS = shape::elementWiseStride(zShapeBuffer);

            auto middle = zLength % 2 == 0 ? zLength / 2 : zLength / 2 + 1;

            int elementsPerThread = middle / TAD_THRESHOLD;
            int _threads = nd4j::math::nd4j_max<int>(1, elementsPerThread);
            _threads = nd4j::math::nd4j_min<int>(_threads, nd4j::Environment::getInstance()->maxThreads());

            int span = (zLength / _threads) + 8;

            // we're enforcing even chunks, since it's mandatory for this algorithm
            span -= span % 2;

            // auto buffer = reinterpret_cast<nd4j::random::RandomBuffer *>(state);
            nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);

            const T mean = extraArguments[0];
            const T stddev = extraArguments[1];

            const T epsilon = static_cast<T>(1e-5);

            auto func = PRAGMA_THREADS_FOR {
                PRAGMA_OMP_SIMD
                for (uint64_t e = start; e < stop; e += increment) {
                    auto epm = e + middle;

                    // we need to get random values
                    T r0 = rng->relativeT<T>(e, epsilon, static_cast<T>(1.0f));
                    T r1 = rng->relativeT<T>(epm, epsilon, static_cast<T>(1.0f));

                    T realMean = y == z ? mean : y[e * yEWS];

                    z[e * zEWS] = nd4j::math::nd4j_exp<T,T>((nd4j::math::nd4j_sqrt<T,T>(static_cast<T>(-2.0f) * nd4j::math::nd4j_log<T,T>(r0)) * nd4j::math::nd4j_cos<T,T>(two_pi * r1)) * stddev + realMean);

                    if (epm < zLength) {
                        realMean = y == z ? mean : y[epm * yEWS];
                        z[epm * zEWS] = nd4j::math::nd4j_exp<T,T>((nd4j::math::nd4j_sqrt<T,T>(static_cast<T>(-2.0f) * nd4j::math::nd4j_log<T,T>(r0)) * nd4j::math::nd4j_sin<T,T>(two_pi * r1)) * stddev + realMean);
                    }
                }
            };

            samediff::Threads::parallel_for(func, 0, middle, 1, _threads);
        }
    };
}

#endif //LIBND4J_SPECIAL_RANDOM_OPS_H
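// Illustrative usage sketch (comment-only addition, not part of the original header):
// these ops are normally reached through the executioner's random-op dispatch rather
// than called directly. Assuming a RandomGenerator constructed from root/node seeds,
// a direct CPU invocation would conceptually look like:
//
//   nd4j::graph::RandomGenerator rng(123, 456);
//   float extras[2] = { mean, stddev };
//   randomOps::GaussianDistribution<float>::specialOp(
//       reinterpret_cast<Nd4jPointer>(&rng),
//       nullptr, nullptr,        // x is unused by this op
//       zBuffer, zShapeInfo,     // passing z as y selects the constant-mean path
//       zBuffer, zShapeInfo,
//       extras);
//
// Here zBuffer/zShapeInfo stand in for a pre-allocated output buffer and its shape
// descriptor, and mean/stddev are the caller's distribution parameters.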