| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | /*******************************************************************************
 | 
					
						
							|  |  |  |  * Copyright (c) 2015-2018 Skymind, Inc. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * This program and the accompanying materials are made available under the | 
					
						
							|  |  |  |  * terms of the Apache License, Version 2.0 which is available at | 
					
						
							|  |  |  |  * https://www.apache.org/licenses/LICENSE-2.0.
 | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  |  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | 
					
						
							|  |  |  |  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | 
					
						
							|  |  |  |  * License for the specific language governing permissions and limitations | 
					
						
							|  |  |  |  * under the License. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * SPDX-License-Identifier: Apache-2.0 | 
					
						
							|  |  |  |  ******************************************************************************/ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | // @author raver119@gmail.com
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #ifndef LIBND4J_SPECIAL_RANDOM_OPS_H
 | 
					
						
							|  |  |  | #define LIBND4J_SPECIAL_RANDOM_OPS_H
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <ops/random_ops.h>
 | 
					
						
							|  |  |  | #include <helpers/shape.h>
 | 
					
						
							|  |  |  | #include <graph/RandomGenerator.h>
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  | #include <ops/specials_cuda.h>
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | #include <execution/Threads.h>
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | namespace randomOps { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  |     template<typename T> | 
					
						
							|  |  |  |     class Choice { | 
					
						
							|  |  |  |     public: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         method_idx | 
					
						
							|  |  |  |         method_X | 
					
						
							|  |  |  |         method_XY | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         static const bool requiresSpecial = true; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #ifdef __CUDACC__
 | 
					
						
							|  |  |  |         __device__ static inline void specialOpCuda(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) { | 
					
						
							|  |  |  |             /**
 | 
					
						
							|  |  |  |              * X holds data, | 
					
						
							|  |  |  |              * Y holds probabilities | 
					
						
							|  |  |  |              * Z will hold results | 
					
						
							|  |  |  |              */ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // TODO: we probably might want to skip this sum, and state that probabilities array should be real probabilities, i.e. should sum to 1.0
 | 
					
						
							|  |  |  |             //T probSum = extraArguments[0];
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             __shared__ Nd4jLong xLength; | 
					
						
							|  |  |  |             __shared__ Nd4jLong yLength; | 
					
						
							|  |  |  |             __shared__ Nd4jLong zLength; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             __shared__ Nd4jLong xEWS; | 
					
						
							|  |  |  |             __shared__ Nd4jLong yEWS; | 
					
						
							|  |  |  |             __shared__ Nd4jLong zEWS; | 
					
						
							|  |  |  |             __shared__ char xOrder; | 
					
						
							|  |  |  |             __shared__ char yOrder; | 
					
						
							|  |  |  |             __shared__ char zOrder; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             __shared__ sd::graph::RandomGenerator *rng; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             __shared__ unsigned char *cB; | 
					
						
							|  |  |  |             __shared__ unsigned char *dB; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             __shared__ sd::graph::RandomGenerator *devRng; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             if (threadIdx.x == 0) { | 
					
						
							|  |  |  |                 extern __shared__ unsigned char shmem[]; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                 rng = (sd::graph::RandomGenerator*) shmem; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 cB = shmem; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                 devRng = reinterpret_cast<sd::graph::RandomGenerator*> (state); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 dB = reinterpret_cast<unsigned char *> (state); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 xLength = shape::length(xShapeBuffer); | 
					
						
							|  |  |  |                 yLength = shape::length(yShapeBuffer); | 
					
						
							|  |  |  |                 zLength = shape::length(zShapeBuffer); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 xEWS = shape::elementWiseStride(xShapeBuffer); | 
					
						
							|  |  |  |                 yEWS = shape::elementWiseStride(yShapeBuffer); | 
					
						
							|  |  |  |                 zEWS = shape::elementWiseStride(zShapeBuffer); | 
					
						
							|  |  |  |                 xOrder = shape::order(xShapeBuffer); | 
					
						
							|  |  |  |                 yOrder = shape::order(yShapeBuffer); | 
					
						
							|  |  |  |                 zOrder = shape::order(zShapeBuffer); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             __syncthreads(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // using this loop instead of memcpy
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 cB[e] = dB[e]; | 
					
						
							| 
									
										
										
										
											2019-07-16 18:48:40 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             __syncthreads(); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             int tid = blockIdx.x * blockDim.x + threadIdx.x; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if (zEWS >= 1 && xEWS >= 1 && yEWS >= 1 && xOrder == yOrder && xOrder == zOrder) { | 
					
						
							|  |  |  |                 for (Nd4jLong e = tid; e < zLength; e+=blockDim.x * gridDim.x) { | 
					
						
							|  |  |  |                     T prob = rng->relativeT<T>(e); | 
					
						
							|  |  |  |                     T cumProb = (T) 0.0f; | 
					
						
							|  |  |  |                     for (Nd4jLong f = 0; f < yLength; f++) { | 
					
						
							|  |  |  |                         T relProb = y[f * yEWS]; | 
					
						
							|  |  |  |                         cumProb += relProb; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         if (prob <= cumProb || f == yLength - 1) { | 
					
						
							|  |  |  |                             z[e * zEWS] = x[f * xEWS]; | 
					
						
							|  |  |  |                             f += yLength; | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  | //                        __syncthreads();  // Eliminated due RTX20xx specific
 | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  | //                    __syncthreads();  // Eliminated due RTX20xx specific
 | 
					
						
							|  |  |  |                 } | 
					
						
							| 
									
										
										
										
											2019-09-11 20:12:09 +03:00
										 |  |  |             } | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             else { | 
					
						
							| 
									
										
										
										
											2019-09-11 20:12:09 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 for (Nd4jLong i = tid; i < zLength; i+=blockDim.x * gridDim.x) { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-09-11 20:12:09 +03:00
										 |  |  |                     auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                     T prob = rng->relativeT<T>(i); | 
					
						
							|  |  |  |                     T cumProb = (T) 0.0f; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     for (Nd4jLong f = 0; f < yLength; f++) { | 
					
						
							| 
									
										
										
										
											2019-09-11 20:12:09 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |                         auto yOffset2 = shape::getIndexOffset(f, yShapeBuffer); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                         T relProb = y[yOffset2]; | 
					
						
							|  |  |  |                         cumProb += relProb; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-09-11 20:12:09 +03:00
										 |  |  |                         if (prob <= cumProb || f == yLength - 1) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                             auto xOffset2 = shape::getIndexOffset(f, xShapeBuffer); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                             z[zOffset2] = x[xOffset2]; | 
					
						
							|  |  |  |                             f += yLength; | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  | //                        __syncthreads();  // Eliminated due RTX20xx specific
 | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  | //                    __syncthreads();  // Eliminated due RTX20xx specific
 | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         static inline void specialOp(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) { | 
					
						
							|  |  |  |             /**
 | 
					
						
							|  |  |  |              * X holds data, | 
					
						
							|  |  |  |              * Y holds probabilities | 
					
						
							|  |  |  |              * Z will hold results | 
					
						
							|  |  |  |              */ | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             //sd::random::RandomBuffer *buffer = reinterpret_cast<sd::random::RandomBuffer *> (state);
 | 
					
						
							|  |  |  |             sd::graph::RandomGenerator* rng = reinterpret_cast<sd::graph::RandomGenerator*>(state); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             // TODO: we probably might want to skip this sum, and state that probabilities array should be real probabilities, i.e. should sum to 1.0
 | 
					
						
							|  |  |  |             //T probSum = extraArguments[0];
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |             auto xLength = shape::length(xShapeBuffer); | 
					
						
							|  |  |  |             auto yLength = shape::length(yShapeBuffer); | 
					
						
							|  |  |  |             auto zLength = shape::length(zShapeBuffer); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             auto xEWS = shape::elementWiseStride(xShapeBuffer); | 
					
						
							|  |  |  |             auto yEWS = shape::elementWiseStride(yShapeBuffer); | 
					
						
							|  |  |  |             auto zEWS = shape::elementWiseStride(zShapeBuffer); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             int elementsPerThread = zLength / TAD_THRESHOLD; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             int _threads = sd::math::nd4j_max<int>(1, elementsPerThread); | 
					
						
							|  |  |  |             _threads = sd::math::nd4j_min<int>(_threads, sd::Environment::getInstance()->maxThreads()); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             if (zEWS >= 1 && xEWS >= 1 && yEWS >= 1) { | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                 auto func = PRAGMA_THREADS_FOR { | 
					
						
							| 
									
										
										
										
											2020-02-20 11:43:26 +03:00
										 |  |  |                     for (auto e = start; e < stop; e++) { | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                         T prob = rng->relativeT<T>(e); | 
					
						
							|  |  |  |                         T cumProb = (T) 0.0f; | 
					
						
							|  |  |  |                         for (Nd4jLong f = 0; f < yLength; f++) { | 
					
						
							|  |  |  |                             T relProb = y[f * yEWS]; | 
					
						
							|  |  |  |                             cumProb += relProb; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                             if (prob <= cumProb || f == yLength - 1) { | 
					
						
							|  |  |  |                                 z[e * zEWS] = x[f * xEWS]; | 
					
						
							|  |  |  |                                 break; | 
					
						
							|  |  |  |                             } | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                         } | 
					
						
							|  |  |  |                     } | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                 }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); | 
					
						
							| 
									
										
										
										
											2019-09-11 20:12:09 +03:00
										 |  |  |             } | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             else { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                 auto func = PRAGMA_THREADS_FOR { | 
					
						
							|  |  |  |                     for (Nd4jLong i = 0; i < zLength; i++) { | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                         auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); | 
					
						
							|  |  |  |                         T prob = rng->relativeT<T>(i); | 
					
						
							|  |  |  |                         T cumProb = (T) 0.0f; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                         for (Nd4jLong f = 0; f < yLength; f++) { | 
					
						
							| 
									
										
										
										
											2019-09-11 20:12:09 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                             auto yOffset2 = shape::getIndexOffset(f, yShapeBuffer); | 
					
						
							|  |  |  |                             T relProb = y[yOffset2]; | 
					
						
							|  |  |  |                             cumProb += relProb; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                             if (prob <= cumProb || f == yLength - 1) { | 
					
						
							| 
									
										
										
										
											2019-09-11 20:12:09 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                                 auto xOffset2 = shape::getIndexOffset(f, xShapeBuffer); | 
					
						
							|  |  |  |                                 z[zOffset2] = x[xOffset2]; | 
					
						
							|  |  |  |                                 break; | 
					
						
							|  |  |  |                             } | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                         } | 
					
						
							|  |  |  |                     } | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                 }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  |     /**
 | 
					
						
							|  |  |  |     * This Op produces random values within specified boundaries. Distribution is Gaussian | 
					
						
							|  |  |  |     */ | 
					
						
							|  |  |  |     template<typename T> | 
					
						
							|  |  |  |     class GaussianDistribution { | 
					
						
							|  |  |  |     public: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         method_XY | 
					
						
							|  |  |  |         method_X | 
					
						
							|  |  |  |         method_idx | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         static const bool requiresSpecial = true; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #ifdef __CUDACC__
 | 
					
						
							|  |  |  |         __device__ static inline void specialOpCuda(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             __shared__ T epsilon; | 
					
						
							|  |  |  |             __shared__ T two_pi; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             __shared__ Nd4jLong zLength; | 
					
						
							|  |  |  |             __shared__ Nd4jLong zEWS; | 
					
						
							|  |  |  |             __shared__ Nd4jLong yEWS; | 
					
						
							|  |  |  |             __shared__ T mean; | 
					
						
							|  |  |  |             __shared__ T stddev; | 
					
						
							|  |  |  |             __shared__ int step; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             __shared__ T *tZ; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             __shared__ sd::graph::RandomGenerator* rng; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             __shared__ unsigned char *cB; | 
					
						
							|  |  |  |             __shared__ unsigned char *dB; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             __shared__ sd::graph::RandomGenerator *devRng; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             if (threadIdx.x == 0) { | 
					
						
							|  |  |  |                 extern __shared__ unsigned char shmem[]; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                 rng = reinterpret_cast<sd::graph::RandomGenerator*>(shmem); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 cB = shmem; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                 devRng = reinterpret_cast<sd::graph::RandomGenerator *> (state); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 dB = reinterpret_cast<unsigned char *> (state); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                 tZ = reinterpret_cast<T *>(shmem + sizeof(sd::graph::RandomGenerator)); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 zLength = shape::length(zShapeBuffer); | 
					
						
							|  |  |  |                 zEWS = shape::elementWiseStride(zShapeBuffer); | 
					
						
							|  |  |  |                 yEWS = shape::elementWiseStride(yShapeBuffer); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 epsilon = static_cast<T>(1e-5); | 
					
						
							|  |  |  |                 two_pi = static_cast<T>(2.0f) * static_cast<T>(3.14159265358979323846); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 mean = extraArguments[0]; | 
					
						
							|  |  |  |                 stddev = extraArguments[1]; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 step = (blockDim.x * gridDim.x); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             __syncthreads(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // using this loop instead of memcpy
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 cB[e] = dB[e]; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-07-16 18:48:40 +03:00
										 |  |  |             __syncthreads(); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             int tid = blockIdx.x * blockDim.x + threadIdx.x; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             int middle = zLength % 2 == 0 ? zLength / 2 : zLength / 2 + 1; | 
					
						
							|  |  |  |             T t(-2.0f); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             for (int e = tid; e < middle; e += step) { | 
					
						
							|  |  |  |                 auto epm = e + middle; | 
					
						
							|  |  |  |                 // we need to get random values
 | 
					
						
							|  |  |  |                 T r0 = rng->relativeT<T>(e, epsilon, static_cast<T>(1.0f)); | 
					
						
							|  |  |  |                 T r1 = rng->relativeT<T>(epm, epsilon, static_cast<T>(1.0f)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 T realMean0 = y == z ? mean : y[e * yEWS]; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                 z[e * zEWS] = (sd::math::nd4j_sqrt<T,T>(t * sd::math::nd4j_log<T,T>(r0)) * sd::math::nd4j_cos<T,T>(two_pi * r1)) * stddev + realMean0; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 if (epm < zLength) { | 
					
						
							|  |  |  |                     T realMean1 = y == z ? mean : y[epm * yEWS]; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                     z[epm * zEWS] =  (sd::math::nd4j_sqrt<T,T>(t * sd::math::nd4j_log<T,T>(r0)) * sd::math::nd4j_sin<T,T>(two_pi * r1)) * stddev + realMean1; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         static inline void | 
					
						
							|  |  |  |         specialOp(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) { | 
					
						
							|  |  |  |             const T two_pi = static_cast<T>(2.0f) * static_cast<T>(3.14159265358979323846); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             auto zLength = shape::length(zShapeBuffer); | 
					
						
							|  |  |  |             auto yEWS = shape::elementWiseStride(yShapeBuffer); | 
					
						
							|  |  |  |             auto zEWS = shape::elementWiseStride(zShapeBuffer); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             auto middle = zLength % 2  + zLength / 2; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             int elementsPerThread = middle / TAD_THRESHOLD; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             int _threads = sd::math::nd4j_max<int>(1, elementsPerThread); | 
					
						
							|  |  |  |             _threads = sd::math::nd4j_min<int>(_threads, sd::Environment::getInstance()->maxThreads()); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             int span = (middle / _threads) + 8; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // we're enforcing even chunks, since it's mandatory for this algorithm
 | 
					
						
							|  |  |  |             span -= span % 2; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             //sd::random::RandomBuffer *buffer = reinterpret_cast<sd::random::RandomBuffer *> (state);
 | 
					
						
							|  |  |  |             sd::graph::RandomGenerator* rng = reinterpret_cast<sd::graph::RandomGenerator*>(state); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             const T mean = extraArguments[0]; | 
					
						
							|  |  |  |             const T stddev = extraArguments[1]; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             const T epsilon = static_cast<T>(1e-5); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |             auto func = PRAGMA_THREADS_FOR { | 
					
						
							| 
									
										
										
										
											2020-02-20 11:43:26 +03:00
										 |  |  |                 for (auto e = start; e < stop; e++) { | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                     auto epm = e + middle; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                     // we need to get random values
 | 
					
						
							|  |  |  |                     T r0 = rng->relativeT<T>(e, epsilon, static_cast<T>(1.0f)); | 
					
						
							|  |  |  |                     T r1 = rng->relativeT<T>(epm, epsilon, static_cast<T>(1.0f)); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                     T realMean0 = y == z ? mean : y[e * yEWS]; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                     auto z0 = (sd::math::nd4j_sqrt<T, T>(static_cast<T>(-2.0f) * sd::math::nd4j_log<T, T>(r0)) * | 
					
						
							|  |  |  |                                sd::math::nd4j_cos<T, T>(two_pi * r1)) * stddev + realMean0; | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                     z[e * zEWS] = z0; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                     if (epm < zLength) { | 
					
						
							|  |  |  |                         T realMean1 = y == z ? mean : y[epm * yEWS]; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                         auto z1 = (sd::math::nd4j_sqrt<T, T>(static_cast<T>(-2.0f) * sd::math::nd4j_log<T, T>(r0)) * | 
					
						
							|  |  |  |                                    sd::math::nd4j_sin<T, T>(two_pi * r1)) * stddev + realMean1; | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                         z[epm * zEWS] = z1; | 
					
						
							|  |  |  |                     } | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 } | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |             }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             samediff::Threads::parallel_for(func, 0, middle, 1, _threads); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |         } | 
					
						
							|  |  |  |     }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  |     /**
 | 
					
						
							|  |  |  |     * This Op produces random values within [0..N], Distribuion is binomial | 
					
						
							|  |  |  |     */ | 
					
						
							|  |  |  |     template<typename T> | 
					
						
							|  |  |  |     class BinomialDistribution { | 
					
						
							|  |  |  |     public: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         method_XY | 
					
						
							|  |  |  |         method_X | 
					
						
							|  |  |  |         method_idx | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         static const bool requiresSpecial = true; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #ifdef __CUDACC__
 | 
					
						
							|  |  |  |         __device__ static inline void specialOpCuda(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) { | 
					
						
							|  |  |  |             int trials = (int) extraArguments[0]; | 
					
						
							|  |  |  |             T prob = extraArguments[1]; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             __shared__ Nd4jLong zLength; | 
					
						
							|  |  |  |             __shared__ int yEWS; | 
					
						
							|  |  |  |             __shared__ int zEWS; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             __shared__ sd::graph::RandomGenerator* rng; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             __shared__ unsigned char *cB; | 
					
						
							|  |  |  |             __shared__ unsigned char *dB; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             __shared__ sd::graph::RandomGenerator *devRng; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             if (threadIdx.x == 0) { | 
					
						
							|  |  |  |                 extern __shared__ unsigned char shmem[]; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                 rng = reinterpret_cast<sd::graph::RandomGenerator*>(shmem); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 cB = shmem; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                 devRng = reinterpret_cast<sd::graph::RandomGenerator*>(state); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 dB = reinterpret_cast<unsigned char *> (state); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 zLength = shape::length(zShapeBuffer); | 
					
						
							|  |  |  |                 yEWS = shape::elementWiseStride(yShapeBuffer); | 
					
						
							|  |  |  |                 zEWS = shape::elementWiseStride(zShapeBuffer); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             __syncthreads(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // using this loop instead of memcpy
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 cB[e] = dB[e]; | 
					
						
							| 
									
										
										
										
											2019-07-16 18:48:40 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             __syncthreads(); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             int tid = blockIdx.x * blockDim.x + threadIdx.x; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             for (Nd4jLong e = tid; e < zLength; e += blockDim.x * gridDim.x) { | 
					
						
							|  |  |  |                 int success = 0; | 
					
						
							|  |  |  |                 for (int t = 1; t <= trials; t++) { | 
					
						
							|  |  |  |                     T randVal = rng->relativeT<T>((e+1) * t); | 
					
						
							|  |  |  |                     if (y != z) { | 
					
						
							|  |  |  |                         // we're using external probs
 | 
					
						
							|  |  |  |                         prob = y[(t-1) * yEWS]; | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     if (randVal < prob) | 
					
						
							|  |  |  |                         success++; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |                 // if trials is set to 0, effectively we just have successful memset
 | 
					
						
							|  |  |  |                 z[e * zEWS] = static_cast<T>(success); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         static inline void specialOp(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) { | 
					
						
							|  |  |  |             int trials = (int) extraArguments[0]; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             Nd4jLong zLength = shape::length(zShapeBuffer); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             auto yEWS = shape::elementWiseStride(yShapeBuffer); | 
					
						
							|  |  |  |             auto zEWS = shape::elementWiseStride(zShapeBuffer); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             int elementsPerThread = zLength / TAD_THRESHOLD; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             int _threads = sd::math::nd4j_max<int>(1, elementsPerThread); | 
					
						
							|  |  |  |             _threads = sd::math::nd4j_min<int>(_threads, sd::Environment::getInstance()->maxThreads()); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |             T prob = extraArguments[1]; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             sd::graph::RandomGenerator* rng = reinterpret_cast<sd::graph::RandomGenerator*>(state); | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |             auto func = PRAGMA_THREADS_FOR { | 
					
						
							| 
									
										
										
										
											2020-02-20 11:43:26 +03:00
										 |  |  |                 for (auto e = start; e < stop; e++) { | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |                     int success = 0; | 
					
						
							|  |  |  |                     for (int t = 1; t <= trials; t++) { | 
					
						
							|  |  |  |                         T randVal = rng->relativeT<T>((e+1) * t); | 
					
						
							|  |  |  |                         if (y != z) { | 
					
						
							|  |  |  |                             // we're using external probs
 | 
					
						
							|  |  |  |                             prob = y[(t-1) * yEWS]; | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         if (randVal < prob) | 
					
						
							|  |  |  |                             success++; | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     // if trials is set to 0, effectively we just have successful memset
 | 
					
						
							|  |  |  |                     z[e * zEWS] = static_cast<T>(success); | 
					
						
							|  |  |  |                 } | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |             }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |         } | 
					
						
							|  |  |  |     }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  |     /**
 | 
					
						
							|  |  |  |     * This Op produces random values within [0..N], Distribuion is binomial | 
					
						
							|  |  |  |     */ | 
					
						
							|  |  |  |     template<typename T> | 
					
						
							|  |  |  |     class BinomialDistributionEx { | 
					
						
							|  |  |  |     public: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         method_XY | 
					
						
							|  |  |  |         method_X | 
					
						
							|  |  |  |         method_idx | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         static const bool requiresSpecial = true; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #ifdef __CUDACC__
 | 
					
						
							|  |  |  |         __device__ static inline void specialOpCuda(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) { | 
					
						
							|  |  |  |             int trials = (int) extraArguments[0]; | 
					
						
							|  |  |  |             T prob = extraArguments[1]; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             __shared__ Nd4jLong zLength; | 
					
						
							|  |  |  |             __shared__ int yEWS; | 
					
						
							|  |  |  |             __shared__ int zEWS; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             __shared__ sd::graph::RandomGenerator* rng; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             __shared__ unsigned char *cB; | 
					
						
							|  |  |  |             __shared__ unsigned char *dB; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             __shared__ sd::graph::RandomGenerator *devRng; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             if (threadIdx.x == 0) { | 
					
						
							|  |  |  |                 extern __shared__ unsigned char shmem[]; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                 rng = (sd::graph::RandomGenerator*) shmem; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 cB = shmem; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                 devRng = reinterpret_cast<sd::graph::RandomGenerator*> (state); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 dB = reinterpret_cast<unsigned char *> (state); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 zLength = shape::length(zShapeBuffer); | 
					
						
							|  |  |  |                 yEWS = shape::elementWiseStride(yShapeBuffer); | 
					
						
							|  |  |  |                 zEWS = shape::elementWiseStride(zShapeBuffer); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             __syncthreads(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // using this loop instead of memcpy
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 cB[e] = dB[e]; | 
					
						
							| 
									
										
										
										
											2019-07-16 18:48:40 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             __syncthreads(); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             int tid = blockIdx.x * blockDim.x + threadIdx.x; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             for (Nd4jLong e = tid; e < zLength; e += blockDim.x * gridDim.x) { | 
					
						
							|  |  |  |                 int success = 0; | 
					
						
							|  |  |  |                 for (int t = 1; t <= trials; t++) { | 
					
						
							|  |  |  |                     T randVal = rng->relativeT<T>((e+1) * t); | 
					
						
							|  |  |  |                     if (y != z) { | 
					
						
							|  |  |  |                         // we're using external probs
 | 
					
						
							|  |  |  |                         prob = y[e * yEWS]; | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     if (randVal < prob) | 
					
						
							|  |  |  |                         success++; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 // if trials is set to 0, effectively we just have successful memset
 | 
					
						
							|  |  |  |                 z[e * zEWS] = (T) success; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         static inline void specialOp(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) { | 
					
						
							|  |  |  |             int trials = (int) extraArguments[0]; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             Nd4jLong zLength = shape::length(zShapeBuffer); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             auto yEWS = shape::elementWiseStride(yShapeBuffer); | 
					
						
							|  |  |  |             auto zEWS = shape::elementWiseStride(zShapeBuffer); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             int elementsPerThread = zLength / TAD_THRESHOLD; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             int _threads = sd::math::nd4j_max<int>(1, elementsPerThread); | 
					
						
							|  |  |  |             _threads = sd::math::nd4j_min<int>(_threads, sd::Environment::getInstance()->maxThreads()); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |             T prob = extraArguments[1]; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             //sd::random::RandomBuffer *buffer = reinterpret_cast<sd::random::RandomBuffer *> (state);
 | 
					
						
							|  |  |  |             sd::graph::RandomGenerator* rng = reinterpret_cast<sd::graph::RandomGenerator*>(state); | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |             auto func = PRAGMA_THREADS_FOR { | 
					
						
							| 
									
										
										
										
											2020-02-20 11:43:26 +03:00
										 |  |  |                 for (auto e = start; e < stop; e++) { | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |                     int success = 0; | 
					
						
							|  |  |  |                     for (int t = 1; t <= trials; t++) { | 
					
						
							|  |  |  |                         T randVal = rng->relativeT<T>((e+1) * t); | 
					
						
							|  |  |  |                         if (y != z) { | 
					
						
							|  |  |  |                             // we're using external probs
 | 
					
						
							|  |  |  |                             prob = y[e * yEWS]; | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         if (randVal < prob) | 
					
						
							|  |  |  |                             success++; | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     // if trials is set to 0, effectively we just have successful memset
 | 
					
						
							|  |  |  |                     z[e * zEWS] = static_cast<T>(success); | 
					
						
							|  |  |  |                 } | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |             }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |         } | 
					
						
							|  |  |  |     }; | 
					
						
							| 
									
										
										
										
											2019-09-11 20:12:09 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | //////////////////////////////////////////////////////////////////////
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |     // This Op produces random Gaussian values within [mean-2*stddev,mean+2*stddev]
 | 
					
						
							|  |  |  |     template<typename T> | 
					
						
							|  |  |  |     class TruncatedNormalDistribution { | 
					
						
							|  |  |  |     private: | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |         static inline _CUDA_HD T step(sd::graph::RandomGenerator* rng, T mean, T stddev, Nd4jLong e, Nd4jLong middle, T& z) { | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             auto epm = e + middle; | 
					
						
							|  |  |  |             const T two_pi = static_cast<T>(2.0f) * static_cast<T>(3.14159265358979323846); | 
					
						
							|  |  |  |             const T epsilon = static_cast<T>(1.e-5f); | 
					
						
							|  |  |  |             // we need to get random values
 | 
					
						
							|  |  |  |             T r0 = rng->relativeT<T>(e, epsilon, static_cast<T>(1.0f)); | 
					
						
							|  |  |  |             T r1 = rng->relativeT<T>(epm, epsilon, static_cast<T>(1.0f)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             T realMean0 = mean; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             auto z0 =  (sd::math::nd4j_sqrt<T,T>(static_cast<T>(-2.0f) * sd::math::nd4j_log<T,T>(r0)) * sd::math::nd4j_cos<T,T>(two_pi * r1)) * stddev + realMean0; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             z = z0; | 
					
						
							|  |  |  |             if (epm < middle) { | 
					
						
							|  |  |  |                 T realMean1 = mean; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                 auto z1 = (sd::math::nd4j_sqrt<T, T>(static_cast<T>(-2.0f) * sd::math::nd4j_log<T, T>(r0)) * | 
					
						
							|  |  |  |                            sd::math::nd4j_sin<T, T>(two_pi * r1)) * stddev + realMean1; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 z = z1; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             return z; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     public: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         method_XY | 
					
						
							|  |  |  |         method_X | 
					
						
							|  |  |  |         method_idx | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         static const bool requiresSpecial = true; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #ifdef __CUDACC__
 | 
					
						
							|  |  |  |         __device__ static inline void specialOpCuda(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) { | 
					
						
							|  |  |  |             __shared__ T epsilon; | 
					
						
							|  |  |  |             __shared__ T two_pi; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             __shared__ Nd4jLong zLength; | 
					
						
							|  |  |  |             __shared__ Nd4jLong zEWS; | 
					
						
							|  |  |  |             __shared__ Nd4jLong yEWS; | 
					
						
							|  |  |  |             __shared__ T mean; | 
					
						
							|  |  |  |             __shared__ T stddev; | 
					
						
							|  |  |  |             __shared__ int step; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             __shared__ T *tZ; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             __shared__ sd::graph::RandomGenerator* rng; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             __shared__ unsigned char *cB; | 
					
						
							|  |  |  |             __shared__ unsigned char *dB; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             __shared__ sd::graph::RandomGenerator* devRng; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             __shared__ Nd4jLong middle; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if (threadIdx.x == 0) { | 
					
						
							|  |  |  |                 extern __shared__ unsigned char shmem[]; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                 rng = reinterpret_cast<sd::graph::RandomGenerator*>(shmem); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 cB = shmem; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                 devRng = reinterpret_cast<sd::graph::RandomGenerator*> (state); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 dB = reinterpret_cast<unsigned char *> (state); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                 tZ = reinterpret_cast<T*>(shmem + sizeof(sd::graph::RandomGenerator)); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 zLength = shape::length(zShapeBuffer); | 
					
						
							|  |  |  |                 zEWS = shape::elementWiseStride(zShapeBuffer); | 
					
						
							|  |  |  |                 yEWS = shape::elementWiseStride(yShapeBuffer); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 epsilon = static_cast<T>(1e-6f); | 
					
						
							|  |  |  |                 two_pi = static_cast<T>(2.0f) * static_cast<T>(3.14159265358979323846); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 mean = extraArguments[0]; | 
					
						
							|  |  |  |                 stddev = extraArguments[1]; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 step = (blockDim.x * gridDim.x); | 
					
						
							|  |  |  |                 middle = zLength / 2 + (zLength % 2); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             __syncthreads(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // using this loop instead of memcpy
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 cB[e] = dB[e]; | 
					
						
							| 
									
										
										
										
											2019-07-16 18:48:40 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             __syncthreads(); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             int tid = blockIdx.x * blockDim.x + threadIdx.x; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             GaussianDistribution<T>::specialOpCuda(state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments); | 
					
						
							|  |  |  |             __syncthreads(); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             T ds = sd::math::nd4j_abs<T>(stddev) * static_cast<T>(2.0f); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             for (Nd4jLong e = tid; e < zLength; e += step) { | 
					
						
							|  |  |  |                 if (z[e] > mean + ds || z[e] < mean - ds) { | 
					
						
							|  |  |  |                     z[e] = TruncatedNormalDistribution<T>::step(rng, mean, stddev, e, middle, z[e]); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     if (z[e] > mean + ds || z[e] < mean - ds) | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                         z[e] = mean + sd::DataTypeUtils::min<T>(); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         static inline void | 
					
						
							|  |  |  |         specialOp(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) { | 
					
						
							|  |  |  |             GaussianDistribution<T>::specialOp(state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments); | 
					
						
							|  |  |  |             Nd4jLong zLength = shape::length(zShapeBuffer); | 
					
						
							|  |  |  |             //auto yEWS = shape::elementWiseStride(yShapeBuffer);
 | 
					
						
							|  |  |  |             //auto zEWS = shape::elementWiseStride(zShapeBuffer);
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             sd::graph::RandomGenerator* rng = reinterpret_cast<sd::graph::RandomGenerator*>(state); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             T mean = extraArguments[0]; | 
					
						
							|  |  |  |             T stddev = extraArguments[1]; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             T ds = sd::math::nd4j_abs<T>(stddev) * (T) 2.0f; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             Nd4jLong middle = zLength / 2 + (zLength % 2); | 
					
						
							|  |  |  |             int elementsPerThread = middle / TAD_THRESHOLD; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             int _threads = sd::math::nd4j_max<int>(1, elementsPerThread); | 
					
						
							|  |  |  |             _threads = sd::math::nd4j_min<int>(_threads, sd::Environment::getInstance()->maxThreads()); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             const T epsilon = static_cast<T>(1e-5); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |             auto func = PRAGMA_THREADS_FOR { | 
					
						
							| 
									
										
										
										
											2020-02-20 11:43:26 +03:00
										 |  |  |                 for (auto e = start; e < stop; e++) { | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                     if (z[e] > mean + ds || z[e] < mean - ds) { | 
					
						
							|  |  |  |                         z[e] = step(rng, mean, stddev, e, middle, z[e]); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                         if (z[e] > mean + ds || z[e] < mean - ds) | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                             z[e] = mean + sd::DataTypeUtils::min<T>(); | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                     } | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 } | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |             }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |         } | 
					
						
							|  |  |  |     }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | // This Op produces random Log-normal distribution
 | 
					
						
							|  |  |  |  template<typename T> | 
					
						
							|  |  |  |     class LogNormalDistribution { | 
					
						
							|  |  |  |     public: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         method_XY | 
					
						
							|  |  |  |         method_X | 
					
						
							|  |  |  |         method_idx | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         static const bool requiresSpecial = true; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #ifdef __CUDACC__
 | 
					
						
							|  |  |  |         __device__ static inline void specialOpCuda(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) { | 
					
						
							|  |  |  |             __shared__ T epsilon; | 
					
						
							|  |  |  |             __shared__ T two_pi; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             __shared__ Nd4jLong zLength; | 
					
						
							|  |  |  |             __shared__ Nd4jLong zEWS; | 
					
						
							|  |  |  |             __shared__ Nd4jLong yEWS; | 
					
						
							|  |  |  |             __shared__ T mean; | 
					
						
							|  |  |  |             __shared__ T stddev; | 
					
						
							|  |  |  |             __shared__ int step; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             __shared__ T *tZ; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             __shared__ sd::graph::RandomGenerator* rng; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             __shared__ unsigned char *cB; | 
					
						
							|  |  |  |             __shared__ unsigned char *dB; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             __shared__ sd::graph::RandomGenerator* devRng; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             if (threadIdx.x == 0) { | 
					
						
							|  |  |  |                 extern __shared__ unsigned char shmem[]; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                 rng = reinterpret_cast<sd::graph::RandomGenerator*>(state); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 cB = shmem; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                 devRng = reinterpret_cast<sd::graph::RandomGenerator*>(state); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 dB = reinterpret_cast<unsigned char *> (state); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                 tZ = reinterpret_cast<T*>(shmem + sizeof(sd::graph::RandomGenerator)); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 zLength = shape::length(zShapeBuffer); | 
					
						
							|  |  |  |                 zEWS = shape::elementWiseStride(zShapeBuffer); | 
					
						
							|  |  |  |                 yEWS = shape::elementWiseStride(yShapeBuffer); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 epsilon = static_cast<T>(1e-5); | 
					
						
							|  |  |  |                 two_pi = static_cast<T>(2.0f) * static_cast<T>(3.14159265358979323846); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 mean = extraArguments[0]; | 
					
						
							|  |  |  |                 stddev = extraArguments[1]; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 step = (blockDim.x * gridDim.x); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             __syncthreads(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // using this loop instead of memcpy
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 cB[e] = dB[e]; | 
					
						
							| 
									
										
										
										
											2019-07-16 18:48:40 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             __syncthreads(); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             int tid = blockIdx.x * blockDim.x + threadIdx.x; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             int middle = zLength % 2 == 0 ? zLength / 2 : zLength / 2 + 1; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             for (Nd4jLong e = tid; e < middle; e += step) { | 
					
						
							|  |  |  |                 auto epm = e + middle; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 // we need to get random values
 | 
					
						
							|  |  |  |                 T r0 = rng->relativeT<T>(e, epsilon, static_cast<T>(1.0f)); | 
					
						
							|  |  |  |                 T r1 = rng->relativeT<T>(epm, epsilon, static_cast<T>(1.0f)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 T realMean = y == z ? mean : y[e * yEWS]; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                 z[e *zEWS] =  sd::math::nd4j_exp<T,T>((sd::math::nd4j_sqrt<T,T>(static_cast<T>(-2.0f) * sd::math::nd4j_log<T,T>(r0)) * sd::math::nd4j_cos<T,T>(two_pi * r1)) * stddev + realMean); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 if (epm < zLength) { | 
					
						
							|  |  |  |                     realMean = y == z ? mean : y[epm * yEWS]; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                     z[epm *zEWS] =  sd::math::nd4j_exp<T,T>((sd::math::nd4j_sqrt<T,T>(static_cast<T>(-2.0f) * sd::math::nd4j_log<T,T>(r0)) * sd::math::nd4j_sin<T,T>(two_pi * r1)) * stddev + realMean); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         static inline void | 
					
						
							|  |  |  |         specialOp(Nd4jPointer state, T *x, Nd4jLong *xShapeBuffer, T *y, Nd4jLong *yShapeBuffer, T *z, Nd4jLong *zShapeBuffer, T *extraArguments) { | 
					
						
							|  |  |  |             const T two_pi = static_cast<T>(2.0f) * static_cast<T>(3.14159265358979323846); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             Nd4jLong zLength = shape::length(zShapeBuffer); | 
					
						
							|  |  |  |             auto yEWS = shape::elementWiseStride(yShapeBuffer); | 
					
						
							|  |  |  |             auto zEWS = shape::elementWiseStride(zShapeBuffer); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             auto middle = zLength % 2 == 0 ? zLength / 2 : zLength / 2 + 1; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             int elementsPerThread = middle / TAD_THRESHOLD; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             int _threads = sd::math::nd4j_max<int>(1, elementsPerThread); | 
					
						
							|  |  |  |             _threads = sd::math::nd4j_min<int>(_threads, sd::Environment::getInstance()->maxThreads()); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             int span = (zLength / _threads) + 8; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // we're enforcing even chunks, since it's mandatory for this algorithm
 | 
					
						
							|  |  |  |             span -= span % 2; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  | //            auto buffer = reinterpret_cast<sd::random::RandomBuffer *> (state);
 | 
					
						
							|  |  |  |             sd::graph::RandomGenerator* rng = reinterpret_cast<sd::graph::RandomGenerator*>(state); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             const T mean = extraArguments[0]; | 
					
						
							|  |  |  |             const T stddev = extraArguments[1]; | 
					
						
							|  |  |  |             const T epsilon = static_cast<T>(1e-5); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |             auto func = PRAGMA_THREADS_FOR { | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 PRAGMA_OMP_SIMD | 
					
						
							| 
									
										
										
										
											2020-02-20 11:43:26 +03:00
										 |  |  |                 for (auto e = start; e < stop; e++) { | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                     auto epm = e + middle; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     // we need to get random values
 | 
					
						
							|  |  |  |                     T r0 = rng->relativeT<T>(e, epsilon, static_cast<T>(1.0f)); | 
					
						
							|  |  |  |                     T r1 = rng->relativeT<T>(epm, epsilon, static_cast<T>(1.0f)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     T realMean = y == z ? mean : y[e * yEWS]; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                     z[e * zEWS] =  sd::math::nd4j_exp<T,T>((sd::math::nd4j_sqrt<T,T>(static_cast<T>(-2.0f) * sd::math::nd4j_log<T,T>(r0)) * sd::math::nd4j_cos<T,T>(two_pi * r1)) * stddev + realMean); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |                     if (epm < zLength) { | 
					
						
							|  |  |  |                         realMean = y == z ? mean : y[epm * yEWS]; | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                         z[epm * zEWS] =  sd::math::nd4j_exp<T,T>((sd::math::nd4j_sqrt<T,T>(static_cast<T>(-2.0f) * sd::math::nd4j_log<T,T>(r0)) * sd::math::nd4j_sin<T,T>(two_pi * r1)) * stddev + realMean); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                     } | 
					
						
							|  |  |  |                 } | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |             }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             samediff::Threads::parallel_for(func, 0, middle, 1, _threads); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |         } | 
					
						
							|  |  |  |     }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #endif //LIBND4J_SPECIAL_RANDOM_OPS_H
 |