| 
									
										
										
										
											2021-02-01 21:31:45 +09:00
										 |  |  | /* ******************************************************************************
 | 
					
						
							|  |  |  |  * | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |  * | 
					
						
							|  |  |  |  * This program and the accompanying materials are made available under the | 
					
						
							|  |  |  |  * terms of the Apache License, Version 2.0 which is available at | 
					
						
							|  |  |  |  * https://www.apache.org/licenses/LICENSE-2.0.
 | 
					
						
							|  |  |  |  * | 
					
						
							| 
									
										
										
										
											2021-02-01 21:31:45 +09:00
										 |  |  |  *  See the NOTICE file distributed with this work for additional | 
					
						
							|  |  |  |  *  information regarding copyright ownership. | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |  * Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  |  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | 
					
						
							|  |  |  |  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | 
					
						
							|  |  |  |  * License for the specific language governing permissions and limitations | 
					
						
							|  |  |  |  * under the License. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * SPDX-License-Identifier: Apache-2.0 | 
					
						
							|  |  |  |  ******************************************************************************/ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | // @author raver119@gmail.com
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | #include <execution/Threads.h>
 | 
					
						
							|  |  |  | #include <execution/ThreadPool.h>
 | 
					
						
							|  |  |  | #include <vector>
 | 
					
						
							|  |  |  | #include <thread>
 | 
					
						
							|  |  |  | #include <helpers/logger.h>
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  | #include <math/templatemath.h>
 | 
					
						
							|  |  |  | #include <helpers/shape.h>
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  | namespace samediff { | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     int ThreadsHelper::numberOfThreads(int maxThreads, uint64_t numberOfElements) { | 
					
						
							|  |  |  |         // let's see how many threads we actually need first
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |         auto optimalThreads = sd::math::nd4j_max<uint64_t>(1, numberOfElements / 1024); | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         // now return the smallest value
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |         return sd::math::nd4j_min<int>(optimalThreads, maxThreads); | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Span3::Span3(int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, int64_t incZ) { | 
					
						
							|  |  |  |         _startX = startX; | 
					
						
							|  |  |  |         _startY = startY; | 
					
						
							|  |  |  |         _startZ = startZ; | 
					
						
							|  |  |  |         _stopX = stopX; | 
					
						
							|  |  |  |         _stopY = stopY; | 
					
						
							|  |  |  |         _stopZ = stopZ; | 
					
						
							|  |  |  |         _incX = incX; | 
					
						
							|  |  |  |         _incY = incY; | 
					
						
							|  |  |  |         _incZ = incZ; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Span3 Span3::build(int loop, uint64_t threadID, uint64_t numThreads, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, int64_t incZ) { | 
					
						
							|  |  |  |         switch (loop) { | 
					
						
							|  |  |  |             case 1: { | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |                     auto span = (stopX - startX) / numThreads; | 
					
						
							|  |  |  |                     auto s = span * threadID; | 
					
						
							|  |  |  |                     auto e = s + span; | 
					
						
							|  |  |  |                     if (threadID == numThreads - 1) | 
					
						
							|  |  |  |                         e = stopX; | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |                     return Span3(s, e, incX, startY, stopY, incY, startZ, stopZ, incZ); | 
					
						
							|  |  |  |                 } | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                 break; | 
					
						
							|  |  |  |             case 2: { | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |                     auto span = (stopY - startY) / numThreads; | 
					
						
							|  |  |  |                     auto s = span * threadID; | 
					
						
							|  |  |  |                     auto e = s + span; | 
					
						
							|  |  |  |                     if (threadID == numThreads - 1) | 
					
						
							|  |  |  |                         e = stopY; | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |                     return Span3(startX, stopX, incX, s, e, incY, startZ, stopZ, incZ); | 
					
						
							|  |  |  |                 } | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                 break; | 
					
						
							|  |  |  |             case 3: { | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |                     auto span = (stopZ - startZ) / numThreads; | 
					
						
							|  |  |  |                     auto s = span * threadID; | 
					
						
							|  |  |  |                     auto e = s + span; | 
					
						
							|  |  |  |                     if (threadID == numThreads - 1) | 
					
						
							|  |  |  |                         e = stopZ; | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |                     return Span3(startX, stopX, incX, startY, stopY, incY, s, e, incZ); | 
					
						
							|  |  |  |                 } | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                 break; | 
					
						
							|  |  |  |             default: | 
					
						
							|  |  |  |                 throw std::runtime_error(""); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return Span3(startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Span::Span(int64_t startX, int64_t stopX, int64_t incX) { | 
					
						
							|  |  |  |         _startX = startX; | 
					
						
							|  |  |  |         _stopX = stopX; | 
					
						
							|  |  |  |         _incX = incX; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Span Span::build(uint64_t threadID, uint64_t numThreads, int64_t startX, int64_t stopX, int64_t incX) { | 
					
						
							|  |  |  |         auto span = (stopX - startX) / numThreads; | 
					
						
							|  |  |  |         auto s = span * threadID; | 
					
						
							|  |  |  |         auto e = s + span; | 
					
						
							|  |  |  |         if (threadID == numThreads - 1) | 
					
						
							|  |  |  |             e = stopX; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return Span(s, e, incX); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Span2::Span2(int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY) { | 
					
						
							|  |  |  |         _startX = startX; | 
					
						
							|  |  |  |         _startY = startY; | 
					
						
							|  |  |  |         _stopX = stopX; | 
					
						
							|  |  |  |         _stopY = stopY; | 
					
						
							|  |  |  |         _incX = incX; | 
					
						
							|  |  |  |         _incY = incY; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Span2 Span2::build(int loop, uint64_t threadID, uint64_t numThreads, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         switch (loop) { | 
					
						
							|  |  |  |             case 1: { | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |                     auto span = (stopX - startX) / numThreads; | 
					
						
							|  |  |  |                     auto s = span * threadID; | 
					
						
							|  |  |  |                     auto e = s + span; | 
					
						
							|  |  |  |                     if (threadID == numThreads - 1) | 
					
						
							|  |  |  |                         e = stopX; | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |                     return Span2(s, e, incX, startY, stopY, incY); | 
					
						
							|  |  |  |                 } | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                 break; | 
					
						
							|  |  |  |             case 2: { | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |                     auto span = (stopY - startY) / numThreads; | 
					
						
							|  |  |  |                     auto s = span * threadID; | 
					
						
							|  |  |  |                     auto e = s + span; | 
					
						
							|  |  |  |                     if (threadID == numThreads - 1) | 
					
						
							|  |  |  |                         e = stopY; | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |                     return Span2(startX, stopX, incX, s, e, incY); | 
					
						
							|  |  |  |                 } | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                 break; | 
					
						
							|  |  |  |             default: | 
					
						
							|  |  |  |                 throw std::runtime_error(""); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t Span::startX() const { | 
					
						
							|  |  |  |         return _startX; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t Span::stopX() const { | 
					
						
							|  |  |  |         return _stopX; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t Span::incX() const { | 
					
						
							|  |  |  |         return _incX; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t Span2::startX() const { | 
					
						
							|  |  |  |         return _startX; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t Span2::startY() const { | 
					
						
							|  |  |  |         return _startY; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t Span2::stopX() const { | 
					
						
							|  |  |  |         return _stopX; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t Span2::stopY() const { | 
					
						
							|  |  |  |         return _stopY; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t Span2::incX() const { | 
					
						
							|  |  |  |         return _incX; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t Span2::incY() const { | 
					
						
							|  |  |  |         return _incY; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t Span3::startX() const { | 
					
						
							|  |  |  |         return _startX; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t Span3::startY() const { | 
					
						
							|  |  |  |         return _startY; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t Span3::startZ() const { | 
					
						
							|  |  |  |         return _startZ; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t Span3::stopX() const { | 
					
						
							|  |  |  |         return _stopX; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t Span3::stopY() const { | 
					
						
							|  |  |  |         return _stopY; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t Span3::stopZ() const { | 
					
						
							|  |  |  |         return _stopZ; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t Span3::incX() const { | 
					
						
							|  |  |  |         return _incX; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t Span3::incY() const { | 
					
						
							|  |  |  |         return _incY; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t Span3::incZ() const { | 
					
						
							|  |  |  |         return _incZ; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int ThreadsHelper::pickLoop2d(int numThreads, uint64_t itersX, uint64_t itersY) { | 
					
						
							|  |  |  |         // if one of dimensions is definitely too small - we just pick the other one
 | 
					
						
							|  |  |  |         if (itersX < numThreads && itersY >= numThreads) | 
					
						
							|  |  |  |             return 2; | 
					
						
							|  |  |  |         if (itersY < numThreads && itersX >= numThreads) | 
					
						
							|  |  |  |             return 1; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // next step - we pick the most balanced dimension
 | 
					
						
							|  |  |  |         auto remX = itersX % numThreads; | 
					
						
							|  |  |  |         auto remY = itersY % numThreads; | 
					
						
							|  |  |  |         auto splitY = itersY / numThreads; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // if there's no remainder left in some dimension - we're picking that dimension, because it'll be the most balanced work distribution
 | 
					
						
							|  |  |  |         if (remX == 0) | 
					
						
							|  |  |  |             return 1; | 
					
						
							|  |  |  |         if (remY == 0) | 
					
						
							|  |  |  |             return 2; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // if there's no loop without a remainder - we're picking one with smaller remainder
 | 
					
						
							|  |  |  |         if (remX < remY) | 
					
						
							|  |  |  |             return 1; | 
					
						
							|  |  |  |         if (remY < remX && splitY >= 64) // we don't want too small splits over last dimension, or vectorization will fail
 | 
					
						
							|  |  |  |             return 2; | 
					
						
							|  |  |  |         // if loops are equally sized - give the preference to the first thread
 | 
					
						
							|  |  |  |         return 1; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     static int threads_(int maxThreads, uint64_t elements) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if (elements == maxThreads) { | 
					
						
							|  |  |  |             return maxThreads; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         else if (elements > maxThreads) { | 
					
						
							|  |  |  |             // if we have full load across thread, or at least half of threads can be utilized
 | 
					
						
							|  |  |  |             auto rem = elements % maxThreads; | 
					
						
							|  |  |  |             if (rem == 0 || rem >= maxThreads / 3) | 
					
						
							|  |  |  |                 return maxThreads; | 
					
						
							|  |  |  |             else | 
					
						
							|  |  |  |                 return threads_(maxThreads - 1, elements); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         else if (elements < maxThreads) { | 
					
						
							|  |  |  |             return elements; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return 1; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int ThreadsHelper::numberOfThreads2d(int maxThreads, uint64_t iters_x, uint64_t iters_y) { | 
					
						
							|  |  |  |         // in some cases there's nothing to think about, part 1
 | 
					
						
							|  |  |  |         if (iters_x < maxThreads && iters_y < maxThreads) | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             return sd::math::nd4j_max<int>(iters_x, iters_y); | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         auto remX = iters_x % maxThreads; | 
					
						
							|  |  |  |         auto remY = iters_y % maxThreads; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // in some cases there's nothing to think about, part 2
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |         if ((iters_x >= maxThreads && remX == 0 )|| (iters_y >= maxThreads && remY == 0)) | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |             return maxThreads; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // at this point we suppose that there's no loop perfectly matches number of our threads
 | 
					
						
							|  |  |  |         // so let's pick something as equal as possible
 | 
					
						
							|  |  |  |         if (iters_x > maxThreads || iters_y > maxThreads) | 
					
						
							|  |  |  |             return maxThreads; | 
					
						
							|  |  |  |         else | 
					
						
							|  |  |  |             return numberOfThreads2d(maxThreads - 1, iters_x, iters_y); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int ThreadsHelper::numberOfThreads3d(int maxThreads, uint64_t itersX, uint64_t itersY, uint64_t itersZ) { | 
					
						
							|  |  |  |         // we don't want to run underloaded threads
 | 
					
						
							|  |  |  |         if (itersX * itersY * itersZ <= 32) | 
					
						
							|  |  |  |             return 1; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto remX = itersX % maxThreads; | 
					
						
							|  |  |  |         auto remY = itersY % maxThreads; | 
					
						
							|  |  |  |         auto remZ = itersZ % maxThreads; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // if we have perfect balance across one of dimensions - just go for it
 | 
					
						
							|  |  |  |         if ((itersX >= maxThreads && remX == 0) || (itersY >= maxThreads && remY == 0) || (itersZ >= maxThreads && remZ == 0)) | 
					
						
							|  |  |  |             return maxThreads; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         int threadsX = 0, threadsY = 0, threadsZ = 0; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // now we look into possible number of
 | 
					
						
							|  |  |  |         threadsX = threads_(maxThreads, itersX); | 
					
						
							|  |  |  |         threadsY = threads_(maxThreads, itersY); | 
					
						
							|  |  |  |         threadsZ = threads_(maxThreads, itersZ); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // we want to split as close to outer loop as possible, so checking it out first
 | 
					
						
							|  |  |  |         if (threadsX >= threadsY && threadsX >= threadsZ) | 
					
						
							|  |  |  |             return threadsX; | 
					
						
							|  |  |  |         else if (threadsY >= threadsX && threadsY >= threadsZ) | 
					
						
							|  |  |  |             return threadsY; | 
					
						
							|  |  |  |         else if (threadsZ >= threadsX && threadsZ >= threadsY) | 
					
						
							|  |  |  |             return threadsZ; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return 1; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int ThreadsHelper::pickLoop3d(int numThreads, uint64_t itersX, uint64_t itersY, uint64_t itersZ) { | 
					
						
							|  |  |  |         auto remX = itersX % numThreads; | 
					
						
							|  |  |  |         auto remY = itersY % numThreads; | 
					
						
							|  |  |  |         auto remZ = itersZ % numThreads; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto splitX = itersX / numThreads; | 
					
						
							|  |  |  |         auto splitY = itersY / numThreads; | 
					
						
							|  |  |  |         auto splitZ = itersZ / numThreads; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // if there's no remainder left in some dimension - we're picking that dimension, because it'll be the most balanced work distribution
 | 
					
						
							|  |  |  |         if (remX == 0) | 
					
						
							|  |  |  |             return 1; | 
					
						
							|  |  |  |         else if (remY == 0) | 
					
						
							|  |  |  |             return 2; | 
					
						
							|  |  |  |         else if (remZ == 0) // TODO: we don't want too smal splits over last dimension? or we do?
 | 
					
						
							|  |  |  |             return 3; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if (itersX > numThreads) | 
					
						
							|  |  |  |             return 1; | 
					
						
							|  |  |  |         else if (itersY > numThreads) | 
					
						
							|  |  |  |             return 2; | 
					
						
							|  |  |  |         else if (itersZ > numThreads) | 
					
						
							|  |  |  |             return 3; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return 1; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int Threads::parallel_tad(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, uint32_t numThreads) { | 
					
						
							|  |  |  |         if (start > stop) | 
					
						
							|  |  |  |             throw std::runtime_error("Threads::parallel_for got start > stop"); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |         auto delta = (stop - start); | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if (numThreads > delta) | 
					
						
							|  |  |  |             numThreads = delta; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if (numThreads == 0) | 
					
						
							|  |  |  |             return 0; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // shortcut
 | 
					
						
							|  |  |  |         if (numThreads == 1) { | 
					
						
							|  |  |  |             function(0, start, stop, increment); | 
					
						
							|  |  |  |             return 1; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-06 15:26:55 +03:00
										 |  |  |         auto ticket = ThreadPool::getInstance().tryAcquire(numThreads); | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |         if (ticket != nullptr) { | 
					
						
							|  |  |  |             // if we got our threads - we'll run our jobs here
 | 
					
						
							|  |  |  |             auto span = delta / numThreads; | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |             for (uint32_t e = 0; e < numThreads; e++) { | 
					
						
							|  |  |  |                 auto start_ = span * e + start; | 
					
						
							|  |  |  |                 auto stop_  = start_ + span; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 // last thread will process tail
 | 
					
						
							|  |  |  |                 if (e == numThreads - 1) | 
					
						
							|  |  |  |                     stop_ = stop; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 // putting the task into the queue for a given thread
 | 
					
						
							|  |  |  |                 ticket->enqueue(e, numThreads, function, start_, stop_, increment); | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |             } | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             // block and wait till all threads finished the job
 | 
					
						
							|  |  |  |             ticket->waitAndRelease(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // we tell that parallelism request succeeded
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |             return numThreads; | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |         } else { | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |             // if there were no threads available - we'll execute function right within current thread
 | 
					
						
							|  |  |  |             function(0, start, stop, increment); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // we tell that parallelism request declined
 | 
					
						
							|  |  |  |             return 1; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int Threads::parallel_for(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, uint32_t numThreads) { | 
					
						
							|  |  |  |         if (start > stop) | 
					
						
							|  |  |  |             throw std::runtime_error("Threads::parallel_for got start > stop"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto delta = (stop - start); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // in some cases we just fire func as is
 | 
					
						
							|  |  |  |         if (delta == 0 || numThreads == 1) { | 
					
						
							|  |  |  |             function(0, start, stop, increment); | 
					
						
							|  |  |  |             return 1; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto numElements = delta / increment; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // we decide what's optimal number of threads we need here, and execute it in parallel_tad.
 | 
					
						
							|  |  |  |         numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements); | 
					
						
							|  |  |  |         return parallel_tad(function, start, stop, increment, numThreads); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int Threads::parallel_for(FUNC_2D function, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, uint64_t numThreads, bool debug) { | 
					
						
							|  |  |  |         if (startX > stopX) | 
					
						
							|  |  |  |             throw std::runtime_error("Threads::parallel_for got startX > stopX"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if (startY > stopY) | 
					
						
							|  |  |  |             throw std::runtime_error("Threads::parallel_for got startY > stopY"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // number of elements per loop
 | 
					
						
							|  |  |  |         auto delta_x = (stopX - startX); | 
					
						
							|  |  |  |         auto delta_y = (stopY - startY); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // number of iterations per loop
 | 
					
						
							|  |  |  |         auto itersX = delta_x / incX; | 
					
						
							|  |  |  |         auto itersY = delta_y / incY; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // total number of iterations
 | 
					
						
							|  |  |  |         auto iters_t = itersX * itersY; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // we are checking the case of number of requested threads was smaller
 | 
					
						
							|  |  |  |         numThreads = ThreadsHelper::numberOfThreads2d(numThreads, itersX, itersY); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // basic shortcut for no-threading cases
 | 
					
						
							|  |  |  |         if (numThreads == 1) { | 
					
						
							|  |  |  |             function(0, startX, stopX, incX, startY, stopY, incY); | 
					
						
							|  |  |  |             return 1; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // We have couple of scenarios:
 | 
					
						
							|  |  |  |         // either we split workload along 1st loop, or 2nd
 | 
					
						
							|  |  |  |         auto splitLoop = ThreadsHelper::pickLoop2d(numThreads, itersX, itersY); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // for debug mode we execute things inplace, without any threads
 | 
					
						
							|  |  |  |         if (debug) { | 
					
						
							|  |  |  |             for (int e = 0; e < numThreads; e++) { | 
					
						
							|  |  |  |                 auto span = Span2::build(splitLoop, e, numThreads, startX, stopX, incX, startY, stopY, incY); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 function(e, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY()); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // but we still mimic multithreaded execution
 | 
					
						
							|  |  |  |             return numThreads; | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |         } else { | 
					
						
							| 
									
										
										
										
											2020-06-06 15:26:55 +03:00
										 |  |  |             auto ticket = ThreadPool::getInstance().tryAcquire(numThreads); | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |             if (ticket != nullptr) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 for (int e = 0; e < numThreads; e++) { | 
					
						
							|  |  |  |                     auto threadId = numThreads - e - 1; | 
					
						
							|  |  |  |                     auto span = Span2::build(splitLoop, threadId, numThreads, startX, stopX, incX, startY, stopY, incY); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY()); | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                 } | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 // block until all threads finish their job
 | 
					
						
							|  |  |  |                 ticket->waitAndRelease(); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                 return numThreads; | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |             } else { | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |                 // if there were no threads available - we'll execute function right within current thread
 | 
					
						
							|  |  |  |                 function(0, startX, stopX, incX, startY, stopY, incY); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 // we tell that parallelism request declined
 | 
					
						
							|  |  |  |                 return 1; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         }; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int Threads::parallel_for(FUNC_3D function, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, int64_t incZ, uint64_t numThreads) { | 
					
						
							|  |  |  |         if (startX > stopX) | 
					
						
							|  |  |  |             throw std::runtime_error("Threads::parallel_for got startX > stopX"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if (startY > stopY) | 
					
						
							|  |  |  |             throw std::runtime_error("Threads::parallel_for got startY > stopY"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if (startZ > stopZ) | 
					
						
							|  |  |  |             throw std::runtime_error("Threads::parallel_for got startZ > stopZ"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto delta_x = stopX - startX; | 
					
						
							|  |  |  |         auto delta_y = stopY - startY; | 
					
						
							|  |  |  |         auto delta_z = stopZ - startZ; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto itersX = delta_x / incX; | 
					
						
							|  |  |  |         auto itersY = delta_y / incY; | 
					
						
							|  |  |  |         auto itersZ = delta_z / incZ; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-19 16:50:08 +03:00
										 |  |  |         numThreads = ThreadsHelper::numberOfThreads3d(numThreads, itersX, itersY, itersZ); | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |         if (numThreads == 1) { | 
					
						
							|  |  |  |             // loop is too small - executing function as is
 | 
					
						
							|  |  |  |             function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); | 
					
						
							|  |  |  |             return 1; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-06 15:26:55 +03:00
										 |  |  |         auto ticket = ThreadPool::getInstance().tryAcquire(numThreads); | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |         if (ticket != nullptr) { | 
					
						
							|  |  |  |             auto splitLoop = ThreadsHelper::pickLoop3d(numThreads, itersX, itersY, itersZ); | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |             for (int e = 0; e < numThreads; e++) { | 
					
						
							|  |  |  |                 auto thread_id = numThreads - e - 1; | 
					
						
							|  |  |  |                 auto span = Span3::build(splitLoop, thread_id, numThreads, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY(), span.startZ(), span.stopZ(), span.incZ()); | 
					
						
							|  |  |  |             } | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |             // block until we're done
 | 
					
						
							|  |  |  |             ticket->waitAndRelease(); | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |             // we tell that parallelism request succeeded
 | 
					
						
							|  |  |  |             return numThreads; | 
					
						
							|  |  |  |         } else { | 
					
						
							|  |  |  |             // if there were no threads available - we'll execute function right within current thread
 | 
					
						
							|  |  |  |             function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |             // we tell that parallelism request declined
 | 
					
						
							|  |  |  |             return 1; | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int Threads::parallel_do(FUNC_DO function, uint64_t numThreads) { | 
					
						
							| 
									
										
										
										
											2020-06-06 15:26:55 +03:00
										 |  |  |         auto ticket = ThreadPool::getInstance().tryAcquire(numThreads - 1); | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |         if (ticket != nullptr) { | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |             // submit tasks one by one
 | 
					
						
							|  |  |  |             for (uint64_t e = 0; e < numThreads - 1; e++) | 
					
						
							|  |  |  |                 ticket->enqueue(e, numThreads, function); | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |             function(numThreads - 1, numThreads); | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |             ticket->waitAndRelease(); | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             return numThreads; | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |         } else { | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |             // if there's no threads available - we'll execute function sequentially one by one
 | 
					
						
							|  |  |  |             for (uint64_t e = 0; e < numThreads; e++) | 
					
						
							|  |  |  |                 function(e, numThreads); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return numThreads; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return numThreads; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t Threads::parallel_long(FUNC_RL function, FUNC_AL aggregator, int64_t start, int64_t stop, int64_t increment, uint64_t numThreads) { | 
					
						
							|  |  |  |         if (start > stop) | 
					
						
							|  |  |  |             throw std::runtime_error("Threads::parallel_long got start > stop"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto delta = (stop - start); | 
					
						
							|  |  |  |         if (delta == 0 || numThreads == 1) | 
					
						
							|  |  |  |             return function(0, start, stop, increment); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto numElements = delta / increment; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // we decide what's optimal number of threads we need here, and execute it
 | 
					
						
							|  |  |  |         numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements); | 
					
						
							|  |  |  |         if (numThreads == 1) | 
					
						
							|  |  |  |             return function(0, start, stop, increment); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-06 15:26:55 +03:00
										 |  |  |         auto ticket = ThreadPool::getInstance().tryAcquire(numThreads - 1); | 
					
						
							| 
									
										
										
										
											2020-03-09 08:21:44 +03:00
										 |  |  |         if (ticket == nullptr) | 
					
						
							|  |  |  |             return function(0, start, stop, increment); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |         // create temporary array
 | 
					
						
							|  |  |  |         int64_t intermediatery[256]; | 
					
						
							| 
									
										
										
										
											2020-05-08 20:59:39 +03:00
										 |  |  |         auto span = (numElements / numThreads) - (numElements % numThreads); | 
					
						
							| 
									
										
										
										
											2020-03-09 08:21:44 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |         // execute threads in parallel
 | 
					
						
							|  |  |  |         for (uint32_t e = 0; e < numThreads; e++) { | 
					
						
							|  |  |  |             auto start_ = span * e + start; | 
					
						
							|  |  |  |             auto stop_ = span * (e + 1) + start; | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |             if (e == numThreads - 1) | 
					
						
							|  |  |  |                 intermediatery[e] = function(e, start_, stop, increment); | 
					
						
							|  |  |  |             else | 
					
						
							|  |  |  |                 ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, stop_, increment); | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2020-03-09 08:21:44 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |         ticket->waitAndRelease(); | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         // aggregate results in single thread
 | 
					
						
							|  |  |  |         for (uint64_t e = 1; e < numThreads; e++) | 
					
						
							|  |  |  |             intermediatery[0] = aggregator(intermediatery[0], intermediatery[e]); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // return accumulated result
 | 
					
						
							|  |  |  |         return intermediatery[0]; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     double Threads::parallel_double(FUNC_RD function, FUNC_AD aggregator, int64_t start, int64_t stop, int64_t increment, uint64_t numThreads) { | 
					
						
							|  |  |  |         if (start > stop) | 
					
						
							|  |  |  |             throw std::runtime_error("Threads::parallel_long got start > stop"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto delta = (stop - start); | 
					
						
							|  |  |  |         if (delta == 0 || numThreads == 1) | 
					
						
							|  |  |  |             return function(0, start, stop, increment); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto numElements = delta / increment; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // we decide what's optimal number of threads we need here, and execute it
 | 
					
						
							|  |  |  |         numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements); | 
					
						
							|  |  |  |         if (numThreads == 1) | 
					
						
							|  |  |  |             return function(0, start, stop, increment); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-06 15:26:55 +03:00
										 |  |  |         auto ticket = ThreadPool::getInstance().tryAcquire(numThreads - 1); | 
					
						
							| 
									
										
										
										
											2020-03-09 08:21:44 +03:00
										 |  |  |         if (ticket == nullptr) | 
					
						
							|  |  |  |             return function(0, start, stop, increment); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |         // create temporary array
 | 
					
						
							|  |  |  |         double intermediatery[256]; | 
					
						
							| 
									
										
										
										
											2020-05-08 20:59:39 +03:00
										 |  |  |         auto span = (numElements / numThreads) - (numElements % numThreads); | 
					
						
							| 
									
										
										
										
											2020-03-09 08:21:44 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |         // execute threads in parallel
 | 
					
						
							|  |  |  |         for (uint32_t e = 0; e < numThreads; e++) { | 
					
						
							|  |  |  |             auto start_ = span * e + start; | 
					
						
							|  |  |  |             auto stop_ = span * (e + 1) + start; | 
					
						
							| 
									
										
										
										
											2020-03-09 08:21:44 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |             if (e == numThreads - 1) | 
					
						
							|  |  |  |                 intermediatery[e] = function(e, start_, stop, increment); | 
					
						
							|  |  |  |             else | 
					
						
							|  |  |  |                 ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, stop_, increment); | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |         ticket->waitAndRelease(); | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         // aggregate results in single thread
 | 
					
						
							|  |  |  |         for (uint64_t e = 1; e < numThreads; e++) | 
					
						
							|  |  |  |             intermediatery[0] = aggregator(intermediatery[0], intermediatery[e]); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // return accumulated result
 | 
					
						
							|  |  |  |         return intermediatery[0]; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-02-08 16:31:30 +04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |     int  Threads::parallel_aligned_increment(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, size_t type_size , uint32_t req_numThreads) { | 
					
						
							| 
									
										
										
										
											2020-02-08 16:31:30 +04:00
										 |  |  |         if (start > stop) | 
					
						
							|  |  |  |             throw std::runtime_error("Threads::parallel_for got start > stop"); | 
					
						
							|  |  |  |         auto num_elements = (stop - start); | 
					
						
							|  |  |  |         //this way we preserve increment starts offset
 | 
					
						
							|  |  |  |         //so we will parition considering delta but not total elements
 | 
					
						
							|  |  |  |         auto delta = (stop - start) / increment; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // in some cases we just fire func as is
 | 
					
						
							|  |  |  |         if (delta == 0 || req_numThreads == 1) { | 
					
						
							|  |  |  |             function(0, start, stop, increment); | 
					
						
							|  |  |  |             return 1; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         int numThreads = 0; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |         int adjusted_numThreads = samediff::ThreadsHelper::numberOfThreads(req_numThreads, (num_elements * sizeof(double)) / (200 * type_size)); | 
					
						
							| 
									
										
										
										
											2020-02-08 16:31:30 +04:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if (adjusted_numThreads > delta) | 
					
						
							|  |  |  |             adjusted_numThreads = delta; | 
					
						
							|  |  |  |         // shortcut
 | 
					
						
							|  |  |  |         if (adjusted_numThreads <= 1) { | 
					
						
							|  |  |  |             function(0, start, stop, increment); | 
					
						
							|  |  |  |             return 1; | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |         //take span as ceil  
 | 
					
						
							| 
									
										
										
										
											2020-02-08 16:31:30 +04:00
										 |  |  |         auto spand = std::ceil((double)delta / (double)adjusted_numThreads); | 
					
						
							|  |  |  |         numThreads = static_cast<int>(std::ceil((double)delta / spand)); | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |         auto span  = static_cast<Nd4jLong>(spand); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-06 15:26:55 +03:00
										 |  |  |         auto ticket = samediff::ThreadPool::getInstance().tryAcquire(numThreads); | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |         if (ticket != nullptr) { | 
					
						
							|  |  |  |             //tail_add is additional value of the last part
 | 
					
						
							|  |  |  |             //it could be negative or positive
 | 
					
						
							|  |  |  |             //we will spread that value across
 | 
					
						
							|  |  |  |             auto tail_add = delta - numThreads * span; | 
					
						
							|  |  |  |             Nd4jLong begin = 0; | 
					
						
							|  |  |  |             Nd4jLong end = 0; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             //we will try enqueu bigger parts first
 | 
					
						
							|  |  |  |             decltype(span) span1, span2; | 
					
						
							|  |  |  |             int last = 0; | 
					
						
							|  |  |  |             if (tail_add >= 0) { | 
					
						
							|  |  |  |                 //for span == 1  , tail_add is  0 
 | 
					
						
							|  |  |  |                 last = tail_add; | 
					
						
							|  |  |  |                 span1 = span + 1; | 
					
						
							|  |  |  |                 span2 = span; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             else { | 
					
						
							|  |  |  |                 last = numThreads + tail_add;// -std::abs(tail_add);
 | 
					
						
							|  |  |  |                 span1 = span; | 
					
						
							|  |  |  |                 span2 = span - 1; | 
					
						
							| 
									
										
										
										
											2020-02-08 16:31:30 +04:00
										 |  |  |             } | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |             for (int i = 0; i < last; i++) { | 
					
						
							|  |  |  |                 end = begin + span1 * increment; | 
					
						
							|  |  |  |                 // putting the task into the queue for a given thread
 | 
					
						
							|  |  |  |                 ticket->enqueue(i, numThreads, function, begin, end, increment); | 
					
						
							|  |  |  |                 begin = end; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             for (int i = last; i < numThreads - 1; i++) { | 
					
						
							|  |  |  |                 end = begin + span2 * increment; | 
					
						
							|  |  |  |                 // putting the task into the queue for a given thread
 | 
					
						
							|  |  |  |                 ticket->enqueue(i, numThreads, function, begin, end, increment); | 
					
						
							|  |  |  |                 begin = end; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             //for last one enqueue last offset as stop
 | 
					
						
							|  |  |  |             //we need it in case our ((stop-start) % increment ) > 0
 | 
					
						
							|  |  |  |             ticket->enqueue(numThreads - 1, numThreads, function, begin, stop, increment); | 
					
						
							|  |  |  |             // block and wait till all threads finished the job
 | 
					
						
							|  |  |  |             ticket->waitAndRelease(); | 
					
						
							|  |  |  |             // we tell that parallelism request succeeded
 | 
					
						
							| 
									
										
										
										
											2020-02-08 16:31:30 +04:00
										 |  |  |             return numThreads; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         else { | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |             // if there were no threads available - we'll execute function right within current thread
 | 
					
						
							| 
									
										
										
										
											2020-02-08 16:31:30 +04:00
										 |  |  |             function(0, start, stop, increment); | 
					
						
							|  |  |  |             // we tell that parallelism request declined
 | 
					
						
							|  |  |  |             return 1; | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2020-02-08 16:31:30 +04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  | } |