Revert "OpenMP Threads execution (#297)" (#299)

This reverts commit dd2043ef48.
master
raver119 2020-03-09 08:22:49 +03:00 committed by GitHub
parent dd2043ef48
commit 57210b936c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
117 changed files with 593 additions and 890 deletions

View File

@ -21,9 +21,9 @@ if (SD_CUDA)
enable_language(CUDA) enable_language(CUDA)
set(CMAKE_CUDA_STANDARD 11) set(CMAKE_CUDA_STANDARD 11)
set(DEFAULT_ENGINE "sd::ENGINE_CUDA") set(DEFAULT_ENGINE "samediff::ENGINE_CUDA")
else() else()
set(DEFAULT_ENGINE "sd::ENGINE_CPU") set(DEFAULT_ENGINE "samediff::ENGINE_CPU")
endif() endif()
# MSVC runtime lib can be either "MultiThreaded" or "MultiThreadedDLL", /MT and /MD respectively # MSVC runtime lib can be either "MultiThreaded" or "MultiThreadedDLL", /MT and /MD respectively

View File

@ -56,7 +56,7 @@ namespace sd {
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e])); buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
}; };
sd::Threads::parallel_for(func, 0, length); samediff::Threads::parallel_for(func, 0, length);
#endif #endif
delete[] tmp; delete[] tmp;
@ -114,7 +114,7 @@ namespace sd {
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e])); buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
}; };
sd::Threads::parallel_for(func, 0, length); samediff::Threads::parallel_for(func, 0, length);
#endif #endif
delete[] tmp; delete[] tmp;
@ -142,7 +142,7 @@ namespace sd {
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e])); buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
}; };
sd::Threads::parallel_for(func, 0, length); samediff::Threads::parallel_for(func, 0, length);
#endif #endif
delete[] tmp; delete[] tmp;
} }
@ -168,7 +168,7 @@ namespace sd {
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e])); buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
}; };
sd::Threads::parallel_for(func, 0, length); samediff::Threads::parallel_for(func, 0, length);
#endif #endif
delete[] tmp; delete[] tmp;
} }

View File

@ -515,7 +515,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
} }
}; };
sd::Threads::parallel_for(func, 0, lengthOf(), 1); samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
tickWriteHost(); tickWriteHost();
syncToDevice(); syncToDevice();
@ -582,7 +582,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::stri
} }
}; };
sd::Threads::parallel_for(func, 0, lengthOf(), 1); samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
tickWriteHost(); tickWriteHost();
@ -648,7 +648,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u16s
} }
} }
}; };
sd::Threads::parallel_for(func, 0, lengthOf(), 1); samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
tickWriteHost(); tickWriteHost();
syncToDevice(); syncToDevice();
@ -714,7 +714,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
} }
} }
}; };
sd::Threads::parallel_for(func, 0, lengthOf(), 1); samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
tickWriteHost(); tickWriteHost();
syncToDevice(); syncToDevice();
@ -780,7 +780,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32s
} }
} }
}; };
sd::Threads::parallel_for(func, 0, lengthOf(), 1); samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
tickWriteHost(); tickWriteHost();
syncToDevice(); syncToDevice();
@ -846,7 +846,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
} }
} }
}; };
sd::Threads::parallel_for(func, 0, lengthOf(), 1); samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
tickWriteHost(); tickWriteHost();
syncToDevice(); syncToDevice();
@ -2393,7 +2393,7 @@ NDArray NDArray::asS() const {
} }
}; };
sd::Threads::parallel_for(func, 0, lengthOf(), 1); samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
registerPrimaryUse({ &res }, { this }); registerPrimaryUse({ &res }, { this });
@ -3466,7 +3466,7 @@ NDArray NDArray::dup(const char newOrder) const {
} }
}; };
sd::Threads::parallel_for(func, 0, lengthOf(), 1); samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
return NDArray(getShapeAsVector(), strings, dataType(), getContext()); return NDArray(getShapeAsVector(), strings, dataType(), getContext());
} }
@ -3479,7 +3479,7 @@ NDArray NDArray::dup(const char newOrder) const {
} }
}; };
sd::Threads::parallel_for(func, 0, lengthOf(), 1); samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
return NDArray(getShapeAsVector(), strings, dataType(), getContext()); return NDArray(getShapeAsVector(), strings, dataType(), getContext());
} }
@ -3491,7 +3491,7 @@ NDArray NDArray::dup(const char newOrder) const {
} }
}; };
sd::Threads::parallel_for(func, 0, lengthOf(), 1); samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
return NDArray(getShapeAsVector(), strings, dataType(), getContext()); return NDArray(getShapeAsVector(), strings, dataType(), getContext());
} }

View File

@ -115,7 +115,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& t
} }
}; };
sd::Threads::parallel_for(func, 0, zLen); samediff::Threads::parallel_for(func, 0, zLen);
} }
BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, (const float val, int lower, int upper, NDArray& target, const char direction), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, (const float val, int lower, int upper, NDArray& target, const char direction), LIBND4J_TYPES);
@ -159,7 +159,7 @@ static void templatedSwap(void *xBuffer, void *yBuffer, Nd4jLong length) {
} }
}; };
sd::Threads::parallel_for(func, 0, length); samediff::Threads::parallel_for(func, 0, length);
} }
BUILD_SINGLE_TEMPLATE(template void templatedSwap, (void *xBuffer, void *yBuffer, Nd4jLong length), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void templatedSwap, (void *xBuffer, void *yBuffer, Nd4jLong length), LIBND4J_TYPES);
@ -272,7 +272,7 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
} }
}; };
sd::Threads::parallel_for(func, 0, resultLen); samediff::Threads::parallel_for(func, 0, resultLen);
} }
else { else {
@ -284,7 +284,7 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
} }
}; };
sd::Threads::parallel_for(func, 0, resultLen); samediff::Threads::parallel_for(func, 0, resultLen);
} }
result.tickWriteHost(); result.tickWriteHost();
return result; return result;
@ -397,7 +397,7 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector<int
} }
}; };
sd::Threads::parallel_for(func, 0, zLen); samediff::Threads::parallel_for(func, 0, zLen);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////

View File

@ -26,7 +26,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
z[e] = func(f[e], s[e], t[e]); z[e] = func(f[e], s[e], t[e]);
}; };
sd::Threads::parallel_for(loop, 0, _length); samediff::Threads::parallel_for(loop, 0, _length);
} else { } else {
if (f == z) { if (f == z) {
@ -40,7 +40,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
} }
}; };
sd::Threads::parallel_for(loop, 0, _length); samediff::Threads::parallel_for(loop, 0, _length);
} else { } else {
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
@ -54,7 +54,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
} }
}; };
sd::Threads::parallel_for(loop, 0, _length); samediff::Threads::parallel_for(loop, 0, _length);
} }
} }
} }
@ -97,7 +97,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
z[e] = func(f[e], s[e]); z[e] = func(f[e], s[e]);
}; };
sd::Threads::parallel_for(loop, 0, _length); samediff::Threads::parallel_for(loop, 0, _length);
} else { } else {
if (f == z) { if (f == z) {
@ -110,7 +110,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
} }
}; };
sd::Threads::parallel_for(loop, 0, _length); samediff::Threads::parallel_for(loop, 0, _length);
} else { } else {
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
@ -123,7 +123,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
} }
}; };
sd::Threads::parallel_for(loop, 0, _length); samediff::Threads::parallel_for(loop, 0, _length);
} }
} }
} }
@ -160,7 +160,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
z[e] = func(f[e]); z[e] = func(f[e]);
}; };
sd::Threads::parallel_for(loop, 0, _length); samediff::Threads::parallel_for(loop, 0, _length);
} else { } else {
if (f == z) { if (f == z) {
@ -172,7 +172,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
} }
}; };
sd::Threads::parallel_for(loop, 0, _length); samediff::Threads::parallel_for(loop, 0, _length);
} else { } else {
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
@ -184,7 +184,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
} }
}; };
sd::Threads::parallel_for(loop, 0, _length); samediff::Threads::parallel_for(loop, 0, _length);
} }
} }
} }
@ -221,7 +221,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
z[e] = func(e, f[e]); z[e] = func(e, f[e]);
}; };
sd::Threads::parallel_for(loop, 0, _length); samediff::Threads::parallel_for(loop, 0, _length);
} else { } else {
if (f == z) { if (f == z) {
@ -233,7 +233,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
} }
}; };
sd::Threads::parallel_for(loop, 0, _length); samediff::Threads::parallel_for(loop, 0, _length);
} else { } else {
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
@ -245,7 +245,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
} }
}; };
sd::Threads::parallel_for(loop, 0, _length); samediff::Threads::parallel_for(loop, 0, _length);
} }
} }
} }
@ -287,7 +287,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
z[e] = func((Nd4jLong) e, f[e], s[e]); z[e] = func((Nd4jLong) e, f[e], s[e]);
}; };
sd::Threads::parallel_for(loop, 0, _length); samediff::Threads::parallel_for(loop, 0, _length);
} else { } else {
if (f == z) { if (f == z) {
@ -300,7 +300,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
} }
}; };
sd::Threads::parallel_for(loop, 0, _length); samediff::Threads::parallel_for(loop, 0, _length);
} else { } else {
auto loop = PRAGMA_THREADS_FOR { auto loop = PRAGMA_THREADS_FOR {
@ -313,7 +313,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
} }
}; };
sd::Threads::parallel_for(loop, 0, _length); samediff::Threads::parallel_for(loop, 0, _length);
} }
} }
} }

View File

@ -27,7 +27,7 @@
#include <atomic> #include <atomic>
#include <condition_variable> #include <condition_variable>
namespace sd { namespace samediff {
template <typename T> template <typename T>
class BlockingQueue { class BlockingQueue {
private: private:

View File

@ -29,7 +29,7 @@
#include <mutex> #include <mutex>
#include <condition_variable> #include <condition_variable>
namespace sd { namespace samediff {
/** /**
* This class is suited for passing functions to execution threads without queues * This class is suited for passing functions to execution threads without queues
*/ */

View File

@ -27,7 +27,7 @@
#include <condition_variable> #include <condition_variable>
#include <system/op_boilerplate.h> #include <system/op_boilerplate.h>
namespace sd { namespace samediff {
class CallableWithArguments { class CallableWithArguments {
FUNC_DO _function_do; FUNC_DO _function_do;
FUNC_1D _function_1d; FUNC_1D _function_1d;

View File

@ -21,7 +21,7 @@
#ifndef SD_ENGINE_H #ifndef SD_ENGINE_H
#define SD_ENGINE_H #define SD_ENGINE_H
namespace sd { namespace samediff {
enum Engine { enum Engine {
ENGINE_CPU = 0, ENGINE_CPU = 0,
ENGINE_CUDA = 1, ENGINE_CUDA = 1,

View File

@ -21,7 +21,7 @@
#ifndef SD_EXECUTIONMODE_H #ifndef SD_EXECUTIONMODE_H
#define SD_EXECUTIONMODE_H #define SD_EXECUTIONMODE_H
namespace sd { namespace samediff {
enum ExecutionMode { enum ExecutionMode {
MODE_UNDEFINED = 0, MODE_UNDEFINED = 0,
MODE_TRAINING = 1, MODE_TRAINING = 1,

View File

@ -32,7 +32,7 @@
#include <execution/Ticket.h> #include <execution/Ticket.h>
#include <queue> #include <queue>
namespace sd { namespace samediff {
class ND4J_EXPORT ThreadPool { class ND4J_EXPORT ThreadPool {
private: private:
static ThreadPool* _INSTANCE; static ThreadPool* _INSTANCE;

View File

@ -14,9 +14,9 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
// //
// @author raver119@gmail.com // @author raver119@gmail.com
// //
#ifndef SAMEDIFF_THREADS_H #ifndef SAMEDIFF_THREADS_H
#define SAMEDIFF_THREADS_H #define SAMEDIFF_THREADS_H
@ -26,7 +26,7 @@
#include <system/Environment.h> #include <system/Environment.h>
#include <system/op_enums.h> #include <system/op_enums.h>
namespace sd { namespace samediff {
class ND4J_EXPORT ThreadsHelper { class ND4J_EXPORT ThreadsHelper {
public: public:
static int numberOfThreads(int maxThreads, uint64_t numberOfElements); static int numberOfThreads(int maxThreads, uint64_t numberOfElements);
@ -95,14 +95,6 @@ namespace sd {
}; };
class ND4J_EXPORT Threads { class ND4J_EXPORT Threads {
#ifdef _OPENMP
public:
static std::mutex gThreadmutex;
static uint64_t _nFreeThreads;
static bool tryAcquire(int numThreads);
static bool freeThreads(int numThreads);
#endif
public: public:
/** /**
* This function executes 1 dimensional loop for a given number of threads * This function executes 1 dimensional loop for a given number of threads

View File

@ -28,7 +28,7 @@
#include <atomic> #include <atomic>
#include <mutex> #include <mutex>
namespace sd { namespace samediff {
class ND4J_EXPORT Ticket { class ND4J_EXPORT Ticket {
private: private:
bool _acquired = false; bool _acquired = false;

View File

@ -22,7 +22,7 @@
#include <execution/CallableWithArguments.h> #include <execution/CallableWithArguments.h>
#include <thread> #include <thread>
namespace sd { namespace samediff {
template <typename T> template <typename T>
BlockingQueue<T>::BlockingQueue(int queueSize) { BlockingQueue<T>::BlockingQueue(int queueSize) {
_size = 0; _size = 0;

View File

@ -21,7 +21,7 @@
#include <execution/CallableInterface.h> #include <execution/CallableInterface.h>
#include <helpers/logger.h> #include <helpers/logger.h>
namespace sd { namespace samediff {
CallableInterface::CallableInterface() { CallableInterface::CallableInterface() {
// initial state is available // initial state is available
_available = true; _available = true;

View File

@ -20,7 +20,7 @@
#include <execution/CallableWithArguments.h> #include <execution/CallableWithArguments.h>
namespace sd { namespace samediff {
CallableWithArguments::CallableWithArguments(FUNC_DO func, uint64_t thread_id, uint64_t numThreads) { CallableWithArguments::CallableWithArguments(FUNC_DO func, uint64_t thread_id, uint64_t numThreads) {
_function_do = func; _function_do = func;
_finished = false; _finished = false;

View File

@ -26,7 +26,7 @@
//#include <windows.h> //#include <windows.h>
#endif #endif
namespace sd { namespace samediff {
// this function executed once per thread, it polls functions from queue, and executes them via wrapper // this function executed once per thread, it polls functions from queue, and executes them via wrapper
static void executionLoop_(int thread_id, BlockingQueue<CallableWithArguments*> *queue) { static void executionLoop_(int thread_id, BlockingQueue<CallableWithArguments*> *queue) {
@ -183,7 +183,7 @@ namespace sd {
} }
} }
void ThreadPool::release(sd::Ticket *ticket) { void ThreadPool::release(samediff::Ticket *ticket) {
// returning ticket back to the queue // returning ticket back to the queue
std::unique_lock<std::mutex> lock(_lock); std::unique_lock<std::mutex> lock(_lock);
_tickets.push(ticket); _tickets.push(ticket);

View File

@ -25,14 +25,8 @@
#include <math/templatemath.h> #include <math/templatemath.h>
#include <helpers/shape.h> #include <helpers/shape.h>
#ifdef _OPENMP
#include <omp.h> namespace samediff {
#endif
namespace sd {
int ThreadsHelper::numberOfThreads(int maxThreads, uint64_t numberOfElements) { int ThreadsHelper::numberOfThreads(int maxThreads, uint64_t numberOfElements) {
// let's see how many threads we actually need first // let's see how many threads we actually need first
@ -57,34 +51,34 @@ namespace sd {
Span3 Span3::build(int loop, uint64_t threadID, uint64_t numThreads, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, int64_t incZ) { Span3 Span3::build(int loop, uint64_t threadID, uint64_t numThreads, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, int64_t incZ) {
switch (loop) { switch (loop) {
case 1: { case 1: {
auto span = (stopX - startX) / numThreads; auto span = (stopX - startX) / numThreads;
auto s = span * threadID; auto s = span * threadID;
auto e = s + span; auto e = s + span;
if (threadID == numThreads - 1) if (threadID == numThreads - 1)
e = stopX; e = stopX;
return Span3(s, e, incX, startY, stopY, incY, startZ, stopZ, incZ); return Span3(s, e, incX, startY, stopY, incY, startZ, stopZ, incZ);
} }
break; break;
case 2: { case 2: {
auto span = (stopY - startY) / numThreads; auto span = (stopY - startY) / numThreads;
auto s = span * threadID; auto s = span * threadID;
auto e = s + span; auto e = s + span;
if (threadID == numThreads - 1) if (threadID == numThreads - 1)
e = stopY; e = stopY;
return Span3(startX, stopX, incX, s, e, incY, startZ, stopZ, incZ); return Span3(startX, stopX, incX, s, e, incY, startZ, stopZ, incZ);
} }
break; break;
case 3: { case 3: {
auto span = (stopZ - startZ) / numThreads; auto span = (stopZ - startZ) / numThreads;
auto s = span * threadID; auto s = span * threadID;
auto e = s + span; auto e = s + span;
if (threadID == numThreads - 1) if (threadID == numThreads - 1)
e = stopZ; e = stopZ;
return Span3(startX, stopX, incX, startY, stopY, incY, s, e, incZ); return Span3(startX, stopX, incX, startY, stopY, incY, s, e, incZ);
} }
break; break;
default: default:
throw std::runtime_error(""); throw std::runtime_error("");
@ -122,24 +116,24 @@ namespace sd {
switch (loop) { switch (loop) {
case 1: { case 1: {
auto span = (stopX - startX) / numThreads; auto span = (stopX - startX) / numThreads;
auto s = span * threadID; auto s = span * threadID;
auto e = s + span; auto e = s + span;
if (threadID == numThreads - 1) if (threadID == numThreads - 1)
e = stopX; e = stopX;
return Span2(s, e, incX, startY, stopY, incY); return Span2(s, e, incX, startY, stopY, incY);
} }
break; break;
case 2: { case 2: {
auto span = (stopY - startY) / numThreads; auto span = (stopY - startY) / numThreads;
auto s = span * threadID; auto s = span * threadID;
auto e = s + span; auto e = s + span;
if (threadID == numThreads - 1) if (threadID == numThreads - 1)
e = stopY; e = stopY;
return Span2(startX, stopX, incX, s, e, incY); return Span2(startX, stopX, incX, s, e, incY);
} }
break; break;
default: default:
throw std::runtime_error(""); throw std::runtime_error("");
@ -276,7 +270,7 @@ namespace sd {
auto remY = iters_y % maxThreads; auto remY = iters_y % maxThreads;
// in some cases there's nothing to think about, part 2 // in some cases there's nothing to think about, part 2
if ((iters_x >= maxThreads && remX == 0) || (iters_y >= maxThreads && remY == 0)) if ((iters_x >= maxThreads && remX == 0 )|| (iters_y >= maxThreads && remY == 0))
return maxThreads; return maxThreads;
// at this point we suppose that there's no loop perfectly matches number of our threads // at this point we suppose that there's no loop perfectly matches number of our threads
@ -345,35 +339,11 @@ namespace sd {
return 1; return 1;
} }
#ifdef _OPENMP
std::mutex Threads::gThreadmutex;
uint64_t Threads::_nFreeThreads = sd::Environment::getInstance()->maxThreads();
bool Threads::tryAcquire(int numThreads){
std::lock_guard<std::mutex> lock( gThreadmutex );
auto nThreads = _nFreeThreads - numThreads;
if(nThreads >= 1){
_nFreeThreads = nThreads;
return true;
}
return false;
}
bool Threads::freeThreads(int numThreads){
std::lock_guard<std::mutex> lock( gThreadmutex );
_nFreeThreads += numThreads;
// check if correct number of threads
return _nFreeThreads > sd::Environment::getInstance()->maxThreads();
}
#endif
int Threads::parallel_tad(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, uint32_t numThreads) { int Threads::parallel_tad(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, uint32_t numThreads) {
if (start > stop) if (start > stop)
throw std::runtime_error("Threads::parallel_for got start > stop"); throw std::runtime_error("Threads::parallel_for got start > stop");
auto delta = (stop - start) / increment; auto delta = (stop - start);
if (numThreads > delta) if (numThreads > delta)
numThreads = delta; numThreads = delta;
@ -387,57 +357,35 @@ namespace sd {
return 1; return 1;
} }
#ifdef _OPENMP auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
if (ticket != nullptr) {
// if we got our threads - we'll run our jobs here
auto span = delta / numThreads;
if (tryAcquire(numThreads)) { for (uint32_t e = 0; e < numThreads; e++) {
#pragma omp parallel for num_threads(numThreads) auto start_ = span * e + start;
for (auto e = start; e < stop; e += increment) { auto stop_ = start_ + span;
function(omp_get_thread_num(), e, e + 1, 1);
// last thread will process tail
if (e == numThreads - 1)
stop_ = stop;
// putting the task into the queue for a given thread
ticket->enqueue(e, numThreads, function, start_, stop_, increment);
} }
freeThreads(numThreads);
// block and wait till all threads finished the job
ticket->waitAndRelease();
// we tell that parallelism request succeeded
return numThreads; return numThreads;
} } else {
else {
// if there were no threads available - we'll execute function right within current thread // if there were no threads available - we'll execute function right within current thread
function(0, start, stop, increment); function(0, start, stop, increment);
// we tell that parallelism request declined // we tell that parallelism request declined
return 1; return 1;
} }
#else
sd::Environment::getInstance()->maxThreads();
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
if (ticket != nullptr) {
// if we got our threads - we'll run our jobs here
auto span = delta / numThreads;
for (uint32_t e = 0; e < numThreads; e++) {
auto start_ = span * e + start;
auto stop_ = start_ + span;
// last thread will process tail
if (e == numThreads - 1)
stop_ = stop;
// putting the task into the queue for a given thread
ticket->enqueue(e, numThreads, function, start_, stop_, increment);
}
// block and wait till all threads finished the job
ticket->waitAndRelease();
// we tell that parallelism request succeeded
return numThreads;
}
else {
// if there were no threads available - we'll execute function right within current thread
function(0, start, stop, increment);
// we tell that parallelism request declined
return 1;
}
#endif
} }
int Threads::parallel_for(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, uint32_t numThreads) { int Threads::parallel_for(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, uint32_t numThreads) {
@ -500,53 +448,28 @@ namespace sd {
// but we still mimic multithreaded execution // but we still mimic multithreaded execution
return numThreads; return numThreads;
} } else {
else { auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
#ifdef _OPENMP if (ticket != nullptr) {
if (tryAcquire(numThreads)) { for (int e = 0; e < numThreads; e++) {
#pragma omp parallel for num_threads(numThreads) collapse(2) auto threadId = numThreads - e - 1;
for (auto x = startX; x < stopX; x += incX) { auto span = Span2::build(splitLoop, threadId, numThreads, startX, stopX, incX, startY, stopY, incY);
for (auto y = startY; y < stopY; y += incY) {
function(omp_get_thread_num(), x, x+1, 1, y, y+1, 1); ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY());
}
} }
freeThreads(numThreads);
// block until all threads finish their job
ticket->waitAndRelease();
return numThreads; return numThreads;
} } else {
else {
// if there were no threads available - we'll execute function right within current thread // if there were no threads available - we'll execute function right within current thread
function(0, startX, stopX, incX, startY, stopY, incY); function(0, startX, stopX, incX, startY, stopY, incY);
// we tell that parallelism request declined // we tell that parallelism request declined
return 1; return 1;
} }
#else
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
if (ticket != nullptr) {
for (int e = 0; e < numThreads; e++) {
auto threadId = numThreads - e - 1;
auto span = Span2::build(splitLoop, threadId, numThreads, startX, stopX, incX, startY, stopY, incY);
ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY());
}
// block until all threads finish their job
ticket->waitAndRelease();
return numThreads;
}
else {
// if there were no threads available - we'll execute function right within current thread
function(0, startX, stopX, incX, startY, stopY, incY);
// we tell that parallelism request declined
return 1;
}
#endif
}; };
} }
@ -561,35 +484,6 @@ namespace sd {
if (startZ > stopZ) if (startZ > stopZ)
throw std::runtime_error("Threads::parallel_for got startZ > stopZ"); throw std::runtime_error("Threads::parallel_for got startZ > stopZ");
if (numThreads == 1) {
// loop is too small - executing function as is
function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
return 1;
}
#ifdef _OPENMP
if (tryAcquire(numThreads)) {
#pragma omp parallel for num_threads(numThreads) collapse(3)
for (auto x = startX; x < stopX; x += incX) {
for (auto y = startY; y < stopY; y += incY) {
for (auto z = startZ; z < stopZ; z += incZ) {
function(omp_get_thread_num(), x, x+1, 1, y, y+1, 1, z, z+1, 1);
}
}
}
freeThreads(numThreads);
return numThreads;
}
else {
// if there were no threads available - we'll execute function right within current thread
function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
// we tell that parallelism request declined
return 1;
}
#else
auto delta_x = stopX - startX; auto delta_x = stopX - startX;
auto delta_y = stopY - startY; auto delta_y = stopY - startY;
auto delta_z = stopZ - startZ; auto delta_z = stopZ - startZ;
@ -606,79 +500,52 @@ namespace sd {
} }
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads); auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
if (ticket != nullptr) { if (ticket != nullptr) {
auto splitLoop = ThreadsHelper::pickLoop3d(numThreads, itersX, itersY, itersZ); auto splitLoop = ThreadsHelper::pickLoop3d(numThreads, itersX, itersY, itersZ);
for (int e = 0; e < numThreads; e++) { for (int e = 0; e < numThreads; e++) {
auto thread_id = numThreads - e - 1; auto thread_id = numThreads - e - 1;
auto span = Span3::build(splitLoop, thread_id, numThreads, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); auto span = Span3::build(splitLoop, thread_id, numThreads, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY(), span.startZ(), span.stopZ(), span.incZ()); ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY(), span.startZ(), span.stopZ(), span.incZ());
} }
// block until we're done // block until we're done
ticket->waitAndRelease(); ticket->waitAndRelease();
// we tell that parallelism request succeeded // we tell that parallelism request succeeded
return numThreads; return numThreads;
} } else {
else { // if there were no threads available - we'll execute function right within current thread
// if there were no threads available - we'll execute function right within current thread function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
// we tell that parallelism request declined // we tell that parallelism request declined
return 1;
}
#endif
}
int Threads::parallel_do(FUNC_DO function, uint64_t numThreads) {
if (numThreads == 1) {
function(0, numThreads);
return 1; return 1;
} }
#ifdef _OPENMP }
if (tryAcquire(numThreads)) { int Threads::parallel_do(FUNC_DO function, uint64_t numThreads) {
#pragma omp parallel for num_threads(numThreads) auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
for (int e = 0; e < numThreads; e++) { if (ticket != nullptr) {
function(e, numThreads);
} // submit tasks one by one
for (uint64_t e = 0; e < numThreads - 1; e++)
ticket->enqueue(e, numThreads, function);
function(numThreads - 1, numThreads);
ticket->waitAndRelease();
freeThreads(numThreads);
return numThreads; return numThreads;
} } else {
else {
// if there's no threads available - we'll execute function sequentially one by one // if there's no threads available - we'll execute function sequentially one by one
for (uint64_t e = 0; e < numThreads; e++) for (uint64_t e = 0; e < numThreads; e++)
function(e, numThreads); function(e, numThreads);
return numThreads; return numThreads;
} }
#else
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
if (ticket != nullptr) {
// submit tasks one by one
for (uint64_t e = 0; e < numThreads - 1; e++)
ticket->enqueue(e, numThreads, function);
function(numThreads - 1, numThreads);
ticket->waitAndRelease();
return numThreads;
}
else {
// if there's no threads available - we'll execute function sequentially one by one
for (uint64_t e = 0; e < numThreads; e++)
function(e, numThreads);
return numThreads;
}
#endif
return numThreads; return numThreads;
} }
@ -698,44 +565,26 @@ namespace sd {
if (numThreads == 1) if (numThreads == 1)
return function(0, start, stop, increment); return function(0, start, stop, increment);
// create temporary array
int64_t intermediatery[256];
auto span = delta / numThreads;
#ifdef _OPENMP
if (tryAcquire(numThreads)) {
#pragma omp parallel for num_threads(numThreads)
for (int e = 0; e < numThreads; e++) {
auto start_ = span * e + start;
auto stop_ = span * (e + 1) + start;
intermediatery[e] = function(e, start_, e == numThreads - 1 ? stop : stop_, increment);
}
freeThreads(numThreads);
}
else{
// if there were no thre ads available - we'll execute function right within current thread
return function(0, start, stop, increment);
}
#else
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1); auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
if (ticket == nullptr) if (ticket == nullptr)
return function(0, start, stop, increment); return function(0, start, stop, increment);
// execute threads in parallel // create temporary array
for (uint32_t e = 0; e < numThreads; e++) { int64_t intermediatery[256];
auto start_ = span * e + start; auto span = delta / numThreads;
auto stop_ = span * (e + 1) + start;
if (e == numThreads - 1) // execute threads in parallel
intermediatery[e] = function(e, start_, stop, increment); for (uint32_t e = 0; e < numThreads; e++) {
else auto start_ = span * e + start;
ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, stop_, increment); auto stop_ = span * (e + 1) + start;
}
ticket->waitAndRelease(); if (e == numThreads - 1)
intermediatery[e] = function(e, start_, stop, increment);
else
ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, stop_, increment);
}
#endif ticket->waitAndRelease();
// aggregate results in single thread // aggregate results in single thread
for (uint64_t e = 1; e < numThreads; e++) for (uint64_t e = 1; e < numThreads; e++)
@ -760,47 +609,26 @@ namespace sd {
if (numThreads == 1) if (numThreads == 1)
return function(0, start, stop, increment); return function(0, start, stop, increment);
// create temporary array
double intermediatery[256];
auto span = delta / numThreads;
#ifdef _OPENMP
if (tryAcquire(numThreads)) {
#pragma omp parallel for num_threads(numThreads)
for (int e = 0; e < numThreads; e++) {
auto start_ = span * e + start;
auto stop_ = span * (e + 1) + start;
intermediatery[e] = function(e, start_, e == numThreads - 1 ? stop : stop_, increment);
}
freeThreads(numThreads);
}
else{
// if there were no thre ads available - we'll execute function right within current thread
return function(0, start, stop, increment);
}
#else
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1); auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
if (ticket == nullptr) if (ticket == nullptr)
return function(0, start, stop, increment); return function(0, start, stop, increment);
// execute threads in parallel // create temporary array
for (uint32_t e = 0; e < numThreads; e++) { double intermediatery[256];
auto start_ = span * e + start; auto span = delta / numThreads;
auto stop_ = span * (e + 1) + start;
if (e == numThreads - 1) // execute threads in parallel
intermediatery[e] = function(e, start_, stop, increment); for (uint32_t e = 0; e < numThreads; e++) {
else auto start_ = span * e + start;
ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, stop_, increment); auto stop_ = span * (e + 1) + start;
}
ticket->waitAndRelease(); if (e == numThreads - 1)
intermediatery[e] = function(e, start_, stop, increment);
else
ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, stop_, increment);
}
#endif ticket->waitAndRelease();
// aggregate results in single thread // aggregate results in single thread
for (uint64_t e = 1; e < numThreads; e++) for (uint64_t e = 1; e < numThreads; e++)
@ -811,7 +639,7 @@ namespace sd {
} }
int Threads::parallel_aligned_increment(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, size_t type_size, uint32_t req_numThreads) { int Threads::parallel_aligned_increment(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, size_t type_size , uint32_t req_numThreads) {
if (start > stop) if (start > stop)
throw std::runtime_error("Threads::parallel_for got start > stop"); throw std::runtime_error("Threads::parallel_for got start > stop");
auto num_elements = (stop - start); auto num_elements = (stop - start);
@ -819,7 +647,6 @@ namespace sd {
//so we will parition considering delta but not total elements //so we will parition considering delta but not total elements
auto delta = (stop - start) / increment; auto delta = (stop - start) / increment;
// in some cases we just fire func as is // in some cases we just fire func as is
if (delta == 0 || req_numThreads == 1) { if (delta == 0 || req_numThreads == 1) {
function(0, start, stop, increment); function(0, start, stop, increment);
@ -827,24 +654,7 @@ namespace sd {
} }
int numThreads = 0; int numThreads = 0;
struct th_span { int adjusted_numThreads = samediff::ThreadsHelper::numberOfThreads(req_numThreads, (num_elements * sizeof(double)) / (200 * type_size));
Nd4jLong start;
Nd4jLong end;
};
#ifdef _OPENMP
constexpr int max_thread_count = 8;
#else
constexpr int max_thread_count = 1024;
#endif
th_span thread_spans[max_thread_count];
req_numThreads = req_numThreads > max_thread_count ? max_thread_count : req_numThreads;
#ifdef _OPENMP
int adjusted_numThreads = max_thread_count;
#else
int adjusted_numThreads = sd::ThreadsHelper::numberOfThreads(req_numThreads, (num_elements * sizeof(double)) / (200 * type_size));
#endif
if (adjusted_numThreads > delta) if (adjusted_numThreads > delta)
adjusted_numThreads = delta; adjusted_numThreads = delta;
@ -853,89 +663,61 @@ namespace sd {
function(0, start, stop, increment); function(0, start, stop, increment);
return 1; return 1;
} }
//take span as ceil
//take span as ceil
auto spand = std::ceil((double)delta / (double)adjusted_numThreads); auto spand = std::ceil((double)delta / (double)adjusted_numThreads);
numThreads = static_cast<int>(std::ceil((double)delta / spand)); numThreads = static_cast<int>(std::ceil((double)delta / spand));
auto span = static_cast<Nd4jLong>(spand); auto span = static_cast<Nd4jLong>(spand);
auto ticket = samediff::ThreadPool::getInstance()->tryAcquire(numThreads);
if (ticket != nullptr) {
//tail_add is additional value of the last part
//it could be negative or positive
//we will spread that value across
auto tail_add = delta - numThreads * span;
Nd4jLong begin = 0;
Nd4jLong end = 0;
//tail_add is additional value of the last part //we will try enqueu bigger parts first
//it could be negative or positive decltype(span) span1, span2;
//we will spread that value across int last = 0;
auto tail_add = delta - numThreads * span; if (tail_add >= 0) {
Nd4jLong begin = 0; //for span == 1 , tail_add is 0
Nd4jLong end = 0; last = tail_add;
span1 = span + 1;
//we will try enqueu bigger parts first span2 = span;
decltype(span) span1, span2;
int last = 0;
if (tail_add >= 0) {
//for span == 1 , tail_add is 0
last = tail_add;
span1 = span + 1;
span2 = span;
}
else {
last = numThreads + tail_add;// -std::abs(tail_add);
span1 = span;
span2 = span - 1;
}
for (int i = 0; i < last; i++) {
end = begin + span1 * increment;
// putting the task into the queue for a given thread
thread_spans[i].start = begin;
thread_spans[i].end = end;
begin = end;
}
for (int i = last; i < numThreads - 1; i++) {
end = begin + span2 * increment;
// putting the task into the queue for a given thread
thread_spans[i].start = begin;
thread_spans[i].end = end;
begin = end;
}
//for last one enqueue last offset as stop
//we need it in case our ((stop-start) % increment ) > 0
thread_spans[numThreads - 1].start = begin;
thread_spans[numThreads - 1].end = stop;
#ifdef _OPENMP
if (tryAcquire(numThreads)) {
#pragma omp parallel for num_threads(numThreads)
for (size_t j = 0; j < numThreads; j++) {
function(j, thread_spans[j].start, thread_spans[j].end, increment);
} }
freeThreads(numThreads); else {
last = numThreads + tail_add;// -std::abs(tail_add);
span1 = span;
span2 = span - 1;
}
for (int i = 0; i < last; i++) {
end = begin + span1 * increment;
// putting the task into the queue for a given thread
ticket->enqueue(i, numThreads, function, begin, end, increment);
begin = end;
}
for (int i = last; i < numThreads - 1; i++) {
end = begin + span2 * increment;
// putting the task into the queue for a given thread
ticket->enqueue(i, numThreads, function, begin, end, increment);
begin = end;
}
//for last one enqueue last offset as stop
//we need it in case our ((stop-start) % increment ) > 0
ticket->enqueue(numThreads - 1, numThreads, function, begin, stop, increment);
// block and wait till all threads finished the job
ticket->waitAndRelease();
// we tell that parallelism request succeeded
return numThreads; return numThreads;
} }
else { else {
// if there were no threads available - we'll execute function right within current thread
function(0, start, stop, increment); function(0, start, stop, increment);
// we tell that parallelism request declined // we tell that parallelism request declined
return 1; return 1;
} }
#else
auto ticket = sd::ThreadPool::getInstance()->tryAcquire(numThreads);
if (ticket != nullptr) {
for (size_t j = 0; j < numThreads; j++) {
ticket->enqueue(j, numThreads, function, thread_spans[j].start, thread_spans[j].end, increment);
}
// block and wait till all threads finished the job
ticket->waitAndRelease();
// we tell that parallelism request succeeded
return numThreads;
}
else {
// if there were no threads available - we'll execute function right within current thread
function(0, start, stop, increment);
// we tell that parallelism request declined
return 1;
}
#endif
} }
}
}

View File

@ -23,7 +23,7 @@
#include <helpers/logger.h> #include <helpers/logger.h>
#include <array> #include <array>
namespace sd { namespace samediff {
Ticket::Ticket(const std::vector<BlockingQueue<CallableWithArguments*>*> &queues) { Ticket::Ticket(const std::vector<BlockingQueue<CallableWithArguments*>*> &queues) {
_acquired = true; _acquired = true;
_queues = queues; _queues = queues;
@ -38,7 +38,7 @@ namespace sd {
return _acquired; return _acquired;
} }
void Ticket::enqueue(int thread_id, sd::CallableWithArguments *callable) { void Ticket::enqueue(int thread_id, samediff::CallableWithArguments *callable) {
_queues[thread_id]->put(callable); _queues[thread_id]->put(callable);
_callables.emplace_back(callable); _callables.emplace_back(callable);
} }
@ -88,7 +88,7 @@ namespace sd {
} }
void Ticket::attach(uint32_t thread_id, sd::CallableInterface *interface) { void Ticket::attach(uint32_t thread_id, samediff::CallableInterface *interface) {
_interfaces[thread_id] = interface; _interfaces[thread_id] = interface;
} }
} }

View File

@ -112,7 +112,7 @@ namespace sd {
sd::random::RandomBuffer* getRNG(); sd::random::RandomBuffer* getRNG();
void setRNG(sd::random::RandomBuffer* rng); void setRNG(sd::random::RandomBuffer* rng);
void setTargetEngine(sd::Engine engine); void setTargetEngine(samediff::Engine engine);
VariableSpace *getVariableSpace(); VariableSpace *getVariableSpace();
@ -228,8 +228,8 @@ namespace sd {
void setShapeFunctionOverride(bool reallyOverride); void setShapeFunctionOverride(bool reallyOverride);
bool shapeFunctionOverride(); bool shapeFunctionOverride();
sd::ExecutionMode executionMode(); samediff::ExecutionMode executionMode();
void setExecutionMode(sd::ExecutionMode executionMode); void setExecutionMode(samediff::ExecutionMode executionMode);
bool isTraining(); bool isTraining();
bool isInference(); bool isInference();

View File

@ -64,9 +64,9 @@ namespace sd {
bool _useMKLDNN = sd::Environment::getInstance()->isUseMKLDNN(); bool _useMKLDNN = sd::Environment::getInstance()->isUseMKLDNN();
// target engine for execution // target engine for execution
sd::Engine _engine = DEFAULT_ENGINE; samediff::Engine _engine = DEFAULT_ENGINE;
sd::ExecutionMode _execMode = sd::ExecutionMode::MODE_UNDEFINED; samediff::ExecutionMode _execMode = samediff::ExecutionMode::MODE_UNDEFINED;
public: public:
explicit ContextPrototype(sd::ops::OpDescriptor* opDescriptor = nullptr, int nodeId = 1, bool inPlace = false); explicit ContextPrototype(sd::ops::OpDescriptor* opDescriptor = nullptr, int nodeId = 1, bool inPlace = false);
~ContextPrototype() = default; ~ContextPrototype() = default;
@ -99,7 +99,7 @@ namespace sd {
std::vector<sd::DataType>* getDArguments(); std::vector<sd::DataType>* getDArguments();
std::vector<int>* getAxis(); std::vector<int>* getAxis();
sd::Engine engine(); samediff::Engine engine();
size_t numT(); size_t numT();
size_t numI(); size_t numI();

View File

@ -107,7 +107,7 @@ namespace sd {
delete _context; delete _context;
} }
void Context::setTargetEngine(sd::Engine engine) { void Context::setTargetEngine(samediff::Engine engine) {
_engine = engine; _engine = engine;
} }
@ -548,20 +548,20 @@ namespace sd {
return _shapeFunctionOverride; return _shapeFunctionOverride;
} }
sd::ExecutionMode Context::executionMode() { samediff::ExecutionMode Context::executionMode() {
return _execMode; return _execMode;
} }
void Context::setExecutionMode(sd::ExecutionMode executionMode) { void Context::setExecutionMode(samediff::ExecutionMode executionMode) {
_execMode = executionMode; _execMode = executionMode;
} }
bool Context::isTraining() { bool Context::isTraining() {
return _execMode == sd::ExecutionMode::MODE_TRAINING; return _execMode == samediff::ExecutionMode::MODE_TRAINING;
} }
bool Context::isInference() { bool Context::isInference() {
return _execMode == sd::ExecutionMode::MODE_INFERENCE; return _execMode == samediff::ExecutionMode::MODE_INFERENCE;
} }
void Context::setDArguments(sd::DataType *arguments, int numberOfArguments) { void Context::setDArguments(sd::DataType *arguments, int numberOfArguments) {

View File

@ -59,7 +59,7 @@ namespace sd {
} }
} }
sd::Engine ContextPrototype::engine() { samediff::Engine ContextPrototype::engine() {
return _engine; return _engine;
} }

View File

@ -511,7 +511,7 @@ namespace sd {
//*********************************************// //*********************************************//
case LoopKind::EWS1: { case LoopKind::EWS1: {
auto span = sd::Span::build(threadId, numThreads, 0, len, 1); auto span = samediff::Span::build(threadId, numThreads, 0, len, 1);
int64_t start = span.startX(), stop = span.stopX(); int64_t start = span.startX(), stop = span.stopX();
for (auto i = start; i < stop; i++) for (auto i = start; i < stop; i++)
@ -524,7 +524,7 @@ namespace sd {
const uint xEws = shape::elementWiseStride(xShapeInfo); const uint xEws = shape::elementWiseStride(xShapeInfo);
const uint zEws = shape::elementWiseStride(zShapeInfo); const uint zEws = shape::elementWiseStride(zShapeInfo);
auto span = sd::Span::build(threadId, numThreads, 0, len, 1); auto span = samediff::Span::build(threadId, numThreads, 0, len, 1);
int64_t start = span.startX(), stop = span.stopX(); int64_t start = span.startX(), stop = span.stopX();
for (auto i = start; i < stop; i++) for (auto i = start; i < stop; i++)
@ -538,7 +538,7 @@ namespace sd {
uint castXShapeInfo[MAX_RANK]; uint castXShapeInfo[MAX_RANK];
const bool canCastX = sd::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, castXShapeInfo); const bool canCastX = sd::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, castXShapeInfo);
auto span = sd::Span::build(threadId, numThreads, 0, len, 1); auto span = samediff::Span::build(threadId, numThreads, 0, len, 1);
int64_t start = span.startX(), stop = span.stopX(); int64_t start = span.startX(), stop = span.stopX();
if (zEws > 1) { if (zEws > 1) {
@ -558,7 +558,7 @@ namespace sd {
//*********************************************// //*********************************************//
case LoopKind::RANK1: { case LoopKind::RANK1: {
auto span = sd::Span::build(threadId, numThreads, 0, len, 1); auto span = samediff::Span::build(threadId, numThreads, 0, len, 1);
for (auto i0 = span.startX(); i0 < span.stopX(); i0++) for (auto i0 = span.startX(); i0 < span.stopX(); i0++)
z[i0 * zStride[0]] = OpType::op(x[i0 * xStride[0]], extraParams); z[i0 * zStride[0]] = OpType::op(x[i0 * xStride[0]], extraParams);
@ -570,8 +570,8 @@ namespace sd {
auto uXShape0 = static_cast<uint>(xShape[0]); auto uXShape0 = static_cast<uint>(xShape[0]);
auto uXShape1 = static_cast<uint>(xShape[1]); auto uXShape1 = static_cast<uint>(xShape[1]);
auto loop = sd::ThreadsHelper::pickLoop2d(numThreads, uXShape0, uXShape1); auto loop = samediff::ThreadsHelper::pickLoop2d(numThreads, uXShape0, uXShape1);
auto span = sd::Span2::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1); auto span = samediff::Span2::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1);
for (auto i0 = span.startX(); i0 < span.stopX(); i0++) { for (auto i0 = span.startX(); i0 < span.stopX(); i0++) {
auto z0 = i0 * zStride[0]; auto z0 = i0 * zStride[0];
@ -589,8 +589,8 @@ namespace sd {
auto uXShape1 = xShape[1]; auto uXShape1 = xShape[1];
auto uXShape2 = xShape[2]; auto uXShape2 = xShape[2];
auto loop = sd::ThreadsHelper::pickLoop2d(numThreads, uXShape0, uXShape1); auto loop = samediff::ThreadsHelper::pickLoop2d(numThreads, uXShape0, uXShape1);
auto span = sd::Span2::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1); auto span = samediff::Span2::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1);
for (auto i0 = span.startX(); i0 < span.stopX(); i0++) for (auto i0 = span.startX(); i0 < span.stopX(); i0++)
@ -611,8 +611,8 @@ namespace sd {
auto uXShape2 = xShape[2]; auto uXShape2 = xShape[2];
auto uXShape3 = xShape[3]; auto uXShape3 = xShape[3];
auto loop = sd::ThreadsHelper::pickLoop3d(numThreads, uXShape0, uXShape1, uXShape2); auto loop = samediff::ThreadsHelper::pickLoop3d(numThreads, uXShape0, uXShape1, uXShape2);
auto span = sd::Span3::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1, 0, uXShape2, 1); auto span = samediff::Span3::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1, 0, uXShape2, 1);
for (auto i0 = span.startX(); i0 < span.stopX(); i0++) for (auto i0 = span.startX(); i0 < span.stopX(); i0++)
for (auto i1 = span.startY(); i1 < span.stopY(); i1++) for (auto i1 = span.startY(); i1 < span.stopY(); i1++)
@ -634,8 +634,8 @@ namespace sd {
auto uXShape3 = xShape[3]; auto uXShape3 = xShape[3];
auto uXShape4 = xShape[4]; auto uXShape4 = xShape[4];
auto loop = sd::ThreadsHelper::pickLoop3d(numThreads, uXShape0, uXShape1, uXShape2); auto loop = samediff::ThreadsHelper::pickLoop3d(numThreads, uXShape0, uXShape1, uXShape2);
auto span = sd::Span3::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1, 0, uXShape2, 1); auto span = samediff::Span3::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1, 0, uXShape2, 1);
for (auto i0 = span.startX(); i0 < span.stopX(); i0++) for (auto i0 = span.startX(); i0 < span.stopX(); i0++)
@ -666,7 +666,7 @@ namespace sd {
bool canCastX = DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); bool canCastX = DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
bool canCastZ = DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); bool canCastZ = DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
auto span = sd::Span::build(threadId, numThreads, 0, len, 1); auto span = samediff::Span::build(threadId, numThreads, 0, len, 1);
for (auto i = span.startX(); i < span.stopX(); i++) { for (auto i = span.startX(); i < span.stopX(); i++) {
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);

View File

@ -93,7 +93,7 @@ static void usualGemm(const NDArray* vA, const NDArray* vB, NDArray* vC,
} }
}; };
sd::Threads::parallel_tad(func, 0, cLen); samediff::Threads::parallel_tad(func, 0, cLen);
} }
@ -146,7 +146,7 @@ static void usualGemv(const NDArray* vA, const NDArray* vX, NDArray* vY, const
} }
}; };
sd::Threads::parallel_tad(func, 0, M); samediff::Threads::parallel_tad(func, 0, M);
} }
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
@ -477,7 +477,7 @@ static void batchedGemm(const NDArray* vA, const NDArray* vB, NDArray* vC,
} }
}; };
sd::Threads::parallel_tad(func, 0, cLen); samediff::Threads::parallel_tad(func, 0, cLen);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -669,7 +669,7 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c
} }
}; };
sd::Threads::parallel_tad(func, 0, M, 1, 0, N, 1); samediff::Threads::parallel_tad(func, 0, M, 1, 0, N, 1);
} }
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
@ -703,7 +703,7 @@ static void usualGemv(const char aOrder, const int M, const int N, const double
} }
}; };
sd::Threads::parallel_tad(func, 0, M); samediff::Threads::parallel_tad(func, 0, M);
} }
*/ */

View File

@ -62,7 +62,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
} }
}; };
sd::Threads::parallel_tad(func, 0, zLen); samediff::Threads::parallel_tad(func, 0, zLen);
} }
break; break;
@ -83,7 +83,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
} }
}; };
sd::Threads::parallel_tad(func, 0, zLen); samediff::Threads::parallel_tad(func, 0, zLen);
} }
break; break;
@ -104,7 +104,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
} }
}; };
sd::Threads::parallel_tad(func, 0, zLen); samediff::Threads::parallel_tad(func, 0, zLen);
} }
break; break;
@ -131,7 +131,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
} }
}; };
sd::Threads::parallel_tad(func, 0, zLen); samediff::Threads::parallel_tad(func, 0, zLen);
} }
break; break;
@ -160,7 +160,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
} }
}; };
sd::Threads::parallel_tad(func, 0, zLen); samediff::Threads::parallel_tad(func, 0, zLen);
} }
break; break;
@ -191,7 +191,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
} }
}; };
sd::Threads::parallel_tad(func, 0, zLen); samediff::Threads::parallel_tad(func, 0, zLen);
} }
break; break;
@ -224,7 +224,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
} }
}; };
sd::Threads::parallel_tad(func, 0, zLen); samediff::Threads::parallel_tad(func, 0, zLen);
} }
break; break;
@ -248,7 +248,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
} }
}; };
sd::Threads::parallel_tad(func, 0, zLen); samediff::Threads::parallel_tad(func, 0, zLen);
} }
break; break;
@ -272,7 +272,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
} }
}; };
sd::Threads::parallel_tad(func, 0, zLen); samediff::Threads::parallel_tad(func, 0, zLen);
} }
break; break;
@ -299,7 +299,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
} }
}; };
sd::Threads::parallel_tad(func, 0, zLen); samediff::Threads::parallel_tad(func, 0, zLen);
} }
} }
} }

View File

@ -99,7 +99,7 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided) reduction(+:_nanCount,_infCount,_m
return _stdDevValue; return _stdDevValue;
}; };
_stdDevValue = sd::Threads::parallel_double(func, LAMBDA_AD { return _old + _new; }, 0, input->lengthOf()); _stdDevValue = samediff::Threads::parallel_double(func, LAMBDA_AD { return _old + _new; }, 0, input->lengthOf());
info->_stdDevValue = math::nd4j_sqrt<double, double>(_stdDevValue / input->lengthOf()); info->_stdDevValue = math::nd4j_sqrt<double, double>(_stdDevValue / input->lengthOf());

View File

@ -199,7 +199,7 @@ void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc,
} }
} }
sd::Threads::parallel_tad(func, 0, numTads); samediff::Threads::parallel_tad(func, 0, numTads);
#endif #endif
} }
@ -237,7 +237,7 @@ void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc,
auto yLen = shape::length(hYShapeInfo); auto yLen = shape::length(hYShapeInfo);
auto numTads = yLen / xLen; auto numTads = yLen / xLen;
sd::Threads::parallel_tad(func, 0, numTads); samediff::Threads::parallel_tad(func, 0, numTads);
#endif #endif
} }
@ -273,7 +273,7 @@ void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext *lc,
auto yLen = shape::length(hYShapeInfo); auto yLen = shape::length(hYShapeInfo);
auto numTads = xLen / yLen; auto numTads = xLen / yLen;
sd::Threads::parallel_tad(func, 0, numTads); samediff::Threads::parallel_tad(func, 0, numTads);
} }
void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc, void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc,
@ -308,7 +308,7 @@ void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc,
auto yLen = shape::length(hYShapeInfo); auto yLen = shape::length(hYShapeInfo);
auto numTads = yLen / xLen; auto numTads = yLen / xLen;
sd::Threads::parallel_tad(func, 0, numTads); samediff::Threads::parallel_tad(func, 0, numTads);
} }
@ -348,7 +348,7 @@ void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc,
auto yLen = shape::length(hYShapeInfo); auto yLen = shape::length(hYShapeInfo);
auto numTads = xLen / yLen; auto numTads = xLen / yLen;
sd::Threads::parallel_tad(func, 0, numTads); samediff::Threads::parallel_tad(func, 0, numTads);
} }
void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc, void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc,
@ -384,7 +384,7 @@ void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc,
auto yLen = shape::length(hYShapeInfo); auto yLen = shape::length(hYShapeInfo);
auto numTads = yLen / xLen; auto numTads = yLen / xLen;
sd::Threads::parallel_tad(func, 0, numTads); samediff::Threads::parallel_tad(func, 0, numTads);
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -427,7 +427,7 @@ void NativeOpExecutioner::execPairwiseTransform(sd::LaunchContext *lc,
}; };
auto zLen = shape::length(hZShapeInfo); auto zLen = shape::length(hZShapeInfo);
sd::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads()))); samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
#endif #endif
} }
@ -462,7 +462,7 @@ void NativeOpExecutioner::execPairwiseBoolTransform(sd::LaunchContext *lc,
}; };
auto zLen = shape::length(hZShapeInfo); auto zLen = shape::length(hZShapeInfo);
sd::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads()))); samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
} }
@ -495,7 +495,7 @@ void NativeOpExecutioner::execPairwiseIntTransform(sd::LaunchContext *lc,
}; };
auto zLen = shape::length(hZShapeInfo); auto zLen = shape::length(hZShapeInfo);
sd::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads()))); samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
} }
@ -534,7 +534,7 @@ void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo); const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
sd::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads()); samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -562,7 +562,7 @@ void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc,
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo); const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
sd::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads()); samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -590,7 +590,7 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo); const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
sd::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads()); samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -618,7 +618,7 @@ void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo); const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
sd::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads()); samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -791,7 +791,7 @@ void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, start, stop), LIBND4J_TYPES, FLOAT_TYPES); BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
}; };
sd::Threads::parallel_tad(func, 0, tadPack.numberOfTads()); samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
} }
@ -820,7 +820,7 @@ void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc,
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execAll(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES); BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execAll(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
}; };
sd::Threads::parallel_tad(func, 0, tadPack.numberOfTads()); samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -861,7 +861,7 @@ void NativeOpExecutioner::execReduce3TAD(sd::LaunchContext *lc,
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES); BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
}; };
sd::Threads::parallel_tad(func, 0, tadPack.numberOfTads()); samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
} }
@ -905,7 +905,7 @@ void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
}; };
auto zLen = shape::length(hZShapeInfo); auto zLen = shape::length(hZShapeInfo);
sd::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads()))); samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
#endif #endif
} }
@ -942,7 +942,7 @@ void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
}; };
auto yLen = shape::length(hScalarShapeInfo); auto yLen = shape::length(hScalarShapeInfo);
sd::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads())); samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads()));
#endif #endif
} }
@ -976,7 +976,7 @@ void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
}; };
auto zLen = shape::length(hZShapeInfo); auto zLen = shape::length(hZShapeInfo);
sd::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads()))); samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
} }
@ -1012,7 +1012,7 @@ void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
}; };
auto yLen = shape::length(hScalarShapeInfo); auto yLen = shape::length(hScalarShapeInfo);
sd::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads())); samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads()));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1044,7 +1044,7 @@ void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
}; };
auto zLen = shape::length(hZShapeInfo); auto zLen = shape::length(hZShapeInfo);
sd::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads()))); samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
} }
@ -1080,7 +1080,7 @@ void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
}; };
auto yLen = shape::length(hScalarShapeInfo); auto yLen = shape::length(hScalarShapeInfo);
sd::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads())); samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads()));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1193,7 +1193,7 @@ void NativeOpExecutioner::execTransformFloat(sd::LaunchContext *lc,
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, FLOAT_TYPES); BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, FLOAT_TYPES);
}; };
sd::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads()))); samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1215,7 +1215,7 @@ void NativeOpExecutioner::execTransformBool(sd::LaunchContext *lc,
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, BOOL_TYPES); BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, BOOL_TYPES);
}; };
sd::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads()))); samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1243,7 +1243,7 @@ void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc,
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES); BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES);
}; };
sd::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads()))); samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
} }
} }
@ -1266,7 +1266,7 @@ void NativeOpExecutioner::execTransformSame(sd::LaunchContext *lc,
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES);
}; };
sd::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads()))); samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1288,7 +1288,7 @@ void NativeOpExecutioner::execTransformStrict(sd::LaunchContext *lc,
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), FLOAT_TYPES);
}; };
sd::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads()))); samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////

View File

@ -1318,7 +1318,7 @@ void pullRowsGeneric(void *vx,
} }
}; };
sd::Threads::parallel_tad(func, 0, n, 1, _threads); samediff::Threads::parallel_tad(func, 0, n, 1, _threads);
} }
void pullRows(Nd4jPointer *extraPointers, void pullRows(Nd4jPointer *extraPointers,
@ -1377,7 +1377,7 @@ void tearGeneric(void *vx,
} }
}; };
sd::Threads::parallel_tad(func,0, numTads); samediff::Threads::parallel_tad(func,0, numTads);
} }
void tear(Nd4jPointer *extraPointers, void tear(Nd4jPointer *extraPointers,
@ -1530,7 +1530,7 @@ void shuffleGeneric(void **hX, Nd4jLong **hXShapeInfo, void **dz, Nd4jLong **hZS
} }
}; };
sd::Threads::parallel_tad(func, 0, N); samediff::Threads::parallel_tad(func, 0, N);
} }
void shuffle(Nd4jPointer *extras, void shuffle(Nd4jPointer *extras,
@ -1944,7 +1944,7 @@ FORCEINLINE int estimateThresholdGeneric(Nd4jPointer *extraPointers, Nd4jPointer
return cnt; return cnt;
}; };
return sd::Threads::parallel_long(func, LAMBDA_AL { return _old + _new; }, 0, N); return samediff::Threads::parallel_long(func, LAMBDA_AL { return _old + _new; }, 0, N);
} }
@ -2653,7 +2653,7 @@ static void _scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSub
} }
}; };
sd::Threads::parallel_do(func); samediff::Threads::parallel_do(func);
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -2812,7 +2812,7 @@ void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) {
if (execMode < 0 || execMode > 2) if (execMode < 0 || execMode > 2)
execMode = 0; execMode = 0;
ptr->setExecutionMode((sd::ExecutionMode) execMode); ptr->setExecutionMode((samediff::ExecutionMode) execMode);
} }
void ctxPurge(OpaqueContext* ptr) { void ctxPurge(OpaqueContext* ptr) {

View File

@ -3799,7 +3799,7 @@ void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) {
if (execMode < 0 || execMode > 2) if (execMode < 0 || execMode > 2)
execMode = 0; execMode = 0;
ptr->setExecutionMode((sd::ExecutionMode) execMode); ptr->setExecutionMode((samediff::ExecutionMode) execMode);
} }
OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) { OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) {

View File

@ -60,7 +60,7 @@ namespace sd {
} }
} }
}; };
sd::Threads::parallel_tad(func, 0, xArr.lengthOf()); samediff::Threads::parallel_tad(func, 0, xArr.lengthOf());
return; return;
} }
@ -95,7 +95,7 @@ namespace sd {
} }
} }
}; };
sd::Threads::parallel_tad(func, 0, nLen, 1); samediff::Threads::parallel_tad(func, 0, nLen, 1);
return; return;
} }
@ -137,7 +137,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, zLen); samediff::Threads::parallel_for(func, 0, zLen);
} }
template <typename X, typename Y, typename Z> template <typename X, typename Y, typename Z>
@ -200,7 +200,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, zLen); samediff::Threads::parallel_for(func, 0, zLen);
} }
template <typename X, typename Y> template <typename X, typename Y>
@ -263,7 +263,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, zLen); samediff::Threads::parallel_for(func, 0, zLen);
} }
template <typename X> template <typename X>

View File

@ -79,7 +79,7 @@ Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vex
} }
}; };
maxThreads = sd::Threads::parallel_for(func, 0, len, 1, maxThreads); maxThreads = samediff::Threads::parallel_for(func, 0, len, 1, maxThreads);
for (int e = 0; e < maxThreads; e++) for (int e = 0; e < maxThreads; e++)
startingIndex = OpType::update(startingIndex, intermediatery[e], extraParams); startingIndex = OpType::update(startingIndex, intermediatery[e], extraParams);
@ -95,7 +95,7 @@ Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vex
} }
}; };
maxThreads = sd::Threads::parallel_for(func, 0, len, 1, maxThreads); maxThreads = samediff::Threads::parallel_for(func, 0, len, 1, maxThreads);
for (int e = 0; e < maxThreads; e++) for (int e = 0; e < maxThreads; e++)
startingIndex = OpType::update(startingIndex, intermediatery[e], extraParams); startingIndex = OpType::update(startingIndex, intermediatery[e], extraParams);

View File

@ -67,7 +67,7 @@ namespace functions {
z[i] = OpClass::op(x[i], y[i], i, length, rng, extraArguments); z[i] = OpClass::op(x[i], y[i], i, length, rng, extraArguments);
} }
}; };
sd::Threads::parallel_for(func, 0, length, 1); samediff::Threads::parallel_for(func, 0, length, 1);
} }
else{ else{
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
@ -81,7 +81,7 @@ namespace functions {
} }
}; };
sd::Threads::parallel_for(func, 0, length, 1); samediff::Threads::parallel_for(func, 0, length, 1);
} }
} }
else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
@ -100,7 +100,7 @@ namespace functions {
} }
}; };
sd::Threads::parallel_for(func, 0, length, 1); samediff::Threads::parallel_for(func, 0, length, 1);
} }
else if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { else if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
@ -118,7 +118,7 @@ namespace functions {
} }
}; };
sd::Threads::parallel_for(func, 0, length, 1); samediff::Threads::parallel_for(func, 0, length, 1);
} }
else if (shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { else if (shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) {
@ -136,7 +136,7 @@ namespace functions {
} }
}; };
sd::Threads::parallel_for(func, 0, length, 1); samediff::Threads::parallel_for(func, 0, length, 1);
} }
else { else {
@ -157,7 +157,7 @@ namespace functions {
} }
}; };
sd::Threads::parallel_for(func, 0, length, 1); samediff::Threads::parallel_for(func, 0, length, 1);
} }
}; };
@ -192,7 +192,7 @@ namespace functions {
z[i] = OpClass::op(x[i], i, length, rng, extraArguments); z[i] = OpClass::op(x[i], i, length, rng, extraArguments);
} }
}; };
sd::Threads::parallel_for(func, 0, length, 1); samediff::Threads::parallel_for(func, 0, length, 1);
} }
else{ else{
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
@ -203,7 +203,7 @@ namespace functions {
} }
}; };
sd::Threads::parallel_for(func, 0, length, 1); samediff::Threads::parallel_for(func, 0, length, 1);
} }
} }
else { else {
@ -220,7 +220,7 @@ namespace functions {
} }
}; };
sd::Threads::parallel_for(func, 0, length, 1); samediff::Threads::parallel_for(func, 0, length, 1);
} }
} }
@ -245,7 +245,7 @@ namespace functions {
} }
}; };
sd::Threads::parallel_for(func, 0, length, 1); samediff::Threads::parallel_for(func, 0, length, 1);
} }
else{ else{
sd::OmpLaunchHelper info(length); sd::OmpLaunchHelper info(length);
@ -261,7 +261,7 @@ namespace functions {
} }
}; };
sd::Threads::parallel_for(func, 0, length, 1); samediff::Threads::parallel_for(func, 0, length, 1);
} }
} }

View File

@ -208,26 +208,13 @@ namespace functions {
Z _CUDA_H ReduceBoolFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) { Z _CUDA_H ReduceBoolFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads()); int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
Z intermediate[64]; Z intermediate[64];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (auto e = 0; e < maxThreads; e++) for (auto e = 0; e < maxThreads; e++)
intermediate[e] = OpType::startingValue(x); intermediate[e] = OpType::startingValue(x);
#ifdef _OPENMP
if (xEws == 1) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
for (Nd4jLong i = 0; i < length; i++)
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i], extraParams), extraParams);
} else {
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
for (Nd4jLong i = 0; i < length; i++)
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i * xEws], extraParams), extraParams);
}
#else
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
if (xEws == 1) { if (xEws == 1) {
for (auto i = start; i < stop; i++) for (auto i = start; i < stop; i++)
@ -238,9 +225,7 @@ namespace functions {
} }
}; };
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads); maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
#endif
// merge results // merge results
for (int e = 1; e < maxThreads; e++) for (int e = 1; e < maxThreads; e++)

View File

@ -72,7 +72,7 @@ namespace functions {
auto startingValue = OpType::startingValue(x); auto startingValue = OpType::startingValue(x);
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads()); int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
Z intermediate[64]; Z intermediate[64];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
@ -84,7 +84,7 @@ namespace functions {
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
}; };
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads); maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
// merge results // merge results
for (int e = 1; e < maxThreads; e++) for (int e = 1; e < maxThreads; e++)
@ -242,27 +242,13 @@ namespace functions {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto extraParams = reinterpret_cast<Z *>(vextraParams); auto extraParams = reinterpret_cast<Z *>(vextraParams);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads()); int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
Z intermediate[64]; Z intermediate[64];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (auto e = 0; e < maxThreads; e++) for (auto e = 0; e < maxThreads; e++)
intermediate[e] = OpType::startingValue(x); intermediate[e] = OpType::startingValue(x);
#ifdef _OPENMP
if (xEws == 1) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
for (Nd4jLong i = 0; i < length; i++)
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i], extraParams), extraParams);
} else {
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
for (Nd4jLong i = 0; i < length; i++)
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i * xEws], extraParams), extraParams);
}
#else
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
if (xEws == 1) { if (xEws == 1) {
for (auto i = start; i < stop; i++) for (auto i = start; i < stop; i++)
@ -273,9 +259,7 @@ namespace functions {
} }
}; };
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads); maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
#endif
// merge results // merge results
for (int e = 1; e < maxThreads; e++) for (int e = 1; e < maxThreads; e++)

View File

@ -67,7 +67,7 @@ namespace functions {
auto startingValue = OpType::startingValue(x); auto startingValue = OpType::startingValue(x);
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads()); int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
Z intermediate[64]; Z intermediate[64];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
@ -79,7 +79,7 @@ namespace functions {
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
}; };
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads); maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
// merge results // merge results
for (int e = 1; e < maxThreads; e++) for (int e = 1; e < maxThreads; e++)
@ -231,26 +231,13 @@ namespace functions {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads()); int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
Z intermediate[64]; Z intermediate[64];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (auto e = 0; e < maxThreads; e++) for (auto e = 0; e < maxThreads; e++)
intermediate[e] = OpType::startingValue(x); intermediate[e] = OpType::startingValue(x);
#ifdef _OPENMP
if (xEws == 1) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
for (Nd4jLong i = 0; i < length; i++)
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i], extraParams), extraParams);
} else {
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
for (Nd4jLong i = 0; i < length; i++)
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i * xEws], extraParams), extraParams);
}
#else
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
if (xEws == 1) { if (xEws == 1) {
for (auto i = start; i < stop; i++) for (auto i = start; i < stop; i++)
@ -261,9 +248,7 @@ namespace functions {
} }
}; };
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads); maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
#endif
// merge results // merge results
for (int e = 1; e < maxThreads; e++) for (int e = 1; e < maxThreads; e++)

View File

@ -69,7 +69,7 @@ namespace functions {
auto startingValue = OpType::startingValue(x); auto startingValue = OpType::startingValue(x);
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads()); int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
X intermediate[64]; X intermediate[64];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
@ -81,7 +81,7 @@ namespace functions {
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
}; };
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads); maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
// merge results // merge results
for (int e = 1; e < maxThreads; e++) for (int e = 1; e < maxThreads; e++)
@ -240,26 +240,13 @@ namespace functions {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads()); int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
X intermediate[64]; X intermediate[64];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (auto e = 0; e < maxThreads; e++) for (auto e = 0; e < maxThreads; e++)
intermediate[e] = OpType::startingValue(x); intermediate[e] = OpType::startingValue(x);
#ifdef _OPENMP
if (xEws == 1) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
for (Nd4jLong i = 0; i < length; i++)
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i], extraParams), extraParams);
} else {
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
for (Nd4jLong i = 0; i < length; i++)
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i * xEws], extraParams), extraParams);
}
#else
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
if (xEws == 1) { if (xEws == 1) {
for (auto i = start; i < stop; i++) for (auto i = start; i < stop; i++)
@ -270,9 +257,7 @@ namespace functions {
} }
}; };
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads); maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
#endif
// merge results // merge results
for (int e = 1; e < maxThreads; e++) for (int e = 1; e < maxThreads; e++)

View File

@ -93,7 +93,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
} }
}; };
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads); maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
} else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { } else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
@ -104,7 +104,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
} }
}; };
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads); maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
} else { } else {
uint yShapeInfoCast[MAX_RANK]; uint yShapeInfoCast[MAX_RANK];
const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
@ -117,7 +117,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
} }
}; };
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads); maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
} }
// merge step // merge step

View File

@ -187,7 +187,7 @@ namespace functions {
} }
}; };
sd::Threads::parallel_tad(func, 0, resultLength, 1); samediff::Threads::parallel_tad(func, 0, resultLength, 1);
} }

View File

@ -86,7 +86,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, N); samediff::Threads::parallel_for(func, 0, N);
} }
template <typename T> template <typename T>
@ -184,7 +184,7 @@ PRAGMA_OMP_ATOMIC_ARGS(write)
} }
}; };
sd::Threads::parallel_for(func, 4, flimit); samediff::Threads::parallel_for(func, 4, flimit);
} }
/** /**
@ -206,7 +206,7 @@ PRAGMA_OMP_ATOMIC_ARGS(write)
z[i] = static_cast<T>(static_cast<float>(x[i])); z[i] = static_cast<T>(static_cast<float>(x[i]));
} }
}; };
sd::Threads::parallel_for(func, 0, N); samediff::Threads::parallel_for(func, 0, N);
}; };
template void TypeCast::convertFromThreshold<float>(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); template void TypeCast::convertFromThreshold<float>(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz);

View File

@ -112,7 +112,7 @@ namespace sd {
*/ */
int prepareOutputs(Context& block); int prepareOutputs(Context& block);
virtual sd::EmptyHandling emptyHandling(); virtual samediff::EmptyHandling emptyHandling();
public: public:
// for special cases, like BooleanOps // for special cases, like BooleanOps
DeclarableOp(); DeclarableOp();

View File

@ -21,7 +21,7 @@
#ifndef SAMEDIFF_EMPTYHANDLING_H #ifndef SAMEDIFF_EMPTYHANDLING_H
#define SAMEDIFF_EMPTYHANDLING_H #define SAMEDIFF_EMPTYHANDLING_H
namespace sd { namespace samediff {
enum EmptyHandling { enum EmptyHandling {
EMPTY_SKIP = 1, EMPTY_SKIP = 1,
EMPTY_EXCEPTION = 2, EMPTY_EXCEPTION = 2,

View File

@ -38,15 +38,15 @@
namespace std { namespace std {
template <> template <>
class hash<std::pair<Nd4jLong, sd::Engine>> { class hash<std::pair<Nd4jLong, samediff::Engine>> {
public: public:
size_t operator()(const std::pair<Nd4jLong, sd::Engine>& k) const; size_t operator()(const std::pair<Nd4jLong, samediff::Engine>& k) const;
}; };
template <> template <>
class hash<std::pair<std::string, sd::Engine>> { class hash<std::pair<std::string, samediff::Engine>> {
public: public:
size_t operator()(const std::pair<std::string, sd::Engine>& k) const; size_t operator()(const std::pair<std::string, samediff::Engine>& k) const;
}; };
}; };
@ -87,8 +87,8 @@ namespace sd {
std::vector<sd::ops::DeclarableOp *> _uniqueD; std::vector<sd::ops::DeclarableOp *> _uniqueD;
// pointers to platform-specific helpers // pointers to platform-specific helpers
MAP_IMPL<std::pair<Nd4jLong, sd::Engine>, sd::ops::platforms::PlatformHelper*> _helpersLH; MAP_IMPL<std::pair<Nd4jLong, samediff::Engine>, sd::ops::platforms::PlatformHelper*> _helpersLH;
MAP_IMPL<std::pair<std::string, sd::Engine>, sd::ops::platforms::PlatformHelper*> _helpersH; MAP_IMPL<std::pair<std::string, samediff::Engine>, sd::ops::platforms::PlatformHelper*> _helpersH;
std::vector<sd::ops::platforms::PlatformHelper*> _uniqueH; std::vector<sd::ops::platforms::PlatformHelper*> _uniqueH;
std::mutex _locker; std::mutex _locker;
@ -119,13 +119,13 @@ namespace sd {
void registerHelper(sd::ops::platforms::PlatformHelper* op); void registerHelper(sd::ops::platforms::PlatformHelper* op);
bool hasHelper(Nd4jLong hash, sd::Engine engine); bool hasHelper(Nd4jLong hash, samediff::Engine engine);
sd::ops::DeclarableOp* getOperation(const char *name); sd::ops::DeclarableOp* getOperation(const char *name);
sd::ops::DeclarableOp* getOperation(Nd4jLong hash); sd::ops::DeclarableOp* getOperation(Nd4jLong hash);
sd::ops::DeclarableOp* getOperation(std::string &name); sd::ops::DeclarableOp* getOperation(std::string &name);
sd::ops::platforms::PlatformHelper* getPlatformHelper(Nd4jLong hash, sd::Engine engine); sd::ops::platforms::PlatformHelper* getPlatformHelper(Nd4jLong hash, samediff::Engine engine);
std::vector<Nd4jLong> getAllHashes(); std::vector<Nd4jLong> getAllHashes();

View File

@ -37,7 +37,7 @@ namespace sd {
class ND4J_EXPORT PlatformHelper { class ND4J_EXPORT PlatformHelper {
protected: protected:
// target engine for this impl // target engine for this impl
sd::Engine _engine; samediff::Engine _engine;
// name of the operation this helper is built for // name of the operation this helper is built for
std::string _name; std::string _name;
@ -45,13 +45,13 @@ namespace sd {
// hash of the operation this helper is built for // hash of the operation this helper is built for
Nd4jLong _hash; Nd4jLong _hash;
public: public:
PlatformHelper(const char *name, sd::Engine engine); PlatformHelper(const char *name, samediff::Engine engine);
~PlatformHelper() = default; ~PlatformHelper() = default;
std::string name(); std::string name();
sd::Engine engine(); samediff::Engine engine();
Nd4jLong hash(); Nd4jLong hash();

View File

@ -174,7 +174,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_tad(func, 0, N); samediff::Threads::parallel_tad(func, 0, N);
} }
void barnes_edge_forces(const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray* output, NDArray const& data) { void barnes_edge_forces(const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray* output, NDArray const& data) {

View File

@ -154,7 +154,7 @@ void prelu(sd::LaunchContext * context, const NDArray& input, const NDArray& alp
} }
}; };
sd::Threads::parallel_for(func, 0, inputLen); samediff::Threads::parallel_for(func, 0, inputLen);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////

View File

@ -565,7 +565,7 @@ namespace sd {
} }
}; };
// //
sd::Threads::parallel_aligned_increment(func, 0, total_num, inc); samediff::Threads::parallel_aligned_increment(func, 0, total_num, inc);
} }
else { else {
//NC...HW case here //NC...HW case here
@ -631,7 +631,7 @@ namespace sd {
} }
}; };
// //
sd::Threads::parallel_aligned_increment(func, 0, total_num, inc); samediff::Threads::parallel_aligned_increment(func, 0, total_num, inc);
} }
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////

View File

@ -55,7 +55,7 @@ static void adjustHue_(const NDArray *input, const NDArray* deltaScalarArr, NDAr
} }
}; };
sd::Threads::parallel_for(func, 0, input->lengthOf(), 3); samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3);
} }
else { else {
@ -87,7 +87,7 @@ static void adjustHue_(const NDArray *input, const NDArray* deltaScalarArr, NDAr
} }
}; };
sd::Threads::parallel_tad(func, 0, numOfTads); samediff::Threads::parallel_tad(func, 0, numOfTads);
} }
} }

View File

@ -56,7 +56,7 @@ static void adjustSaturation_(const NDArray *input, const NDArray* factorScalarA
} }
}; };
sd::Threads::parallel_for(func, 0, input->lengthOf(), 3); samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3);
} else { } else {
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC); auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC); auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC);
@ -84,7 +84,7 @@ static void adjustSaturation_(const NDArray *input, const NDArray* factorScalarA
} }
}; };
sd::Threads::parallel_tad(func, 0, numOfTads); samediff::Threads::parallel_tad(func, 0, numOfTads);
} }
} }

View File

@ -114,7 +114,7 @@ void bgemm_(const std::vector<NDArray*>& vA, const std::vector<NDArray*>& vB, st
} }
}; };
sd::Threads::parallel_tad(func, 0, vaSize); samediff::Threads::parallel_tad(func, 0, vaSize);
} }
} }

View File

@ -106,7 +106,7 @@ static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray*
delete []zOffsets; delete []zOffsets;
}; };
sd::Threads::parallel_do(func, info._numThreads); samediff::Threads::parallel_do(func, info._numThreads);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -178,7 +178,7 @@ static void batchnorm2_(const NDArray* input, const NDArray* mean, const NDArray
} }
}; };
sd::Threads::parallel_for(func, 0, input->lengthOf()); samediff::Threads::parallel_for(func, 0, input->lengthOf());
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////

View File

@ -121,7 +121,7 @@ static void betaIncForArray(sd::LaunchContext * context, const NDArray& a, const
output.t<T>(i) = betaIncCore<T>(a.t<T>(i), b.t<T>(i), x.t<T>(i)); output.t<T>(i) = betaIncCore<T>(a.t<T>(i), b.t<T>(i), x.t<T>(i));
}; };
sd::Threads::parallel_for(func, 0, xLen); samediff::Threads::parallel_for(func, 0, xLen);
} }
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////

View File

@ -89,7 +89,7 @@ void col2im_(sd::LaunchContext & context, const NDArray& input, NDArray& output
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
} }
else { else {
@ -127,7 +127,7 @@ void col2im_(sd::LaunchContext & context, const NDArray& input, NDArray& output
} }
}; };
sd::Threads::parallel_tad(func, 0, bS); samediff::Threads::parallel_tad(func, 0, bS);
} }
} }

View File

@ -40,7 +40,7 @@ namespace sd {
} }
return sum; return sum;
}; };
sumt = sd::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1); sumt = samediff::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1);
} else { } else {
//PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:sum) //PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:sum)
auto func = PRAGMA_REDUCE_LONG { auto func = PRAGMA_REDUCE_LONG {
@ -53,7 +53,7 @@ namespace sd {
return sum; return sum;
}; };
sumt = sd::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1); sumt = samediff::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1);
} }
//nd4j_printf("Sum: %lld\n", sumt) //nd4j_printf("Sum: %lld\n", sumt)

View File

@ -40,7 +40,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_for(func, 0, lLen); samediff::Threads::parallel_for(func, 0, lLen);
} }
void confusionFunctor(sd::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) { void confusionFunctor(sd::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) {

View File

@ -101,7 +101,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, kD, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, kD, 1);
} else { } else {
@ -139,7 +139,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, oD, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, oD, 1);
//func(0, 0, bS, 1, 0, oD, 1); //func(0, 0, bS, 1, 0, oD, 1);
} }
} }
@ -215,7 +215,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_tad(func, 0, bS); samediff::Threads::parallel_tad(func, 0, bS);
} else { } else {
@ -251,7 +251,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_tad(func, 0, bS); samediff::Threads::parallel_tad(func, 0, bS);
} }
} }
@ -606,7 +606,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oH, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oH, 1);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -663,7 +663,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -716,7 +716,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iH, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iH, 1);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -777,7 +777,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iD, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iD, 1);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -860,7 +860,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
} }
/*************************************************************************/ /*************************************************************************/
else if(poolingMode == 1) { // avg else if(poolingMode == 1) { // avg
@ -914,7 +914,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
} }
/*************************************************************************/ /*************************************************************************/
else if(poolingMode == 2) { // pnorm else if(poolingMode == 2) { // pnorm
@ -963,7 +963,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
} }
else { else {
nd4j_printf("ConvolutionUtils::pooling2d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); nd4j_printf("ConvolutionUtils::pooling2d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
@ -1068,7 +1068,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
} }
/*************************************************************************/ /*************************************************************************/
else if(poolingMode == 1) { // avg else if(poolingMode == 1) { // avg
@ -1131,7 +1131,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
} }
/*************************************************************************/ /*************************************************************************/
else if(poolingMode == 2) { // pnorm else if(poolingMode == 2) { // pnorm
@ -1191,7 +1191,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
} }
else { else {
nd4j_printf("ConvolutionUtils::pooling3d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); nd4j_printf("ConvolutionUtils::pooling3d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
@ -1321,7 +1321,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
} }
/*************************************************************************/ /*************************************************************************/
else if(poolingMode == 1) { // avg else if(poolingMode == 1) { // avg
@ -1379,7 +1379,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
} }
/*************************************************************************/ /*************************************************************************/
else if(poolingMode == 2) { // pnorm else if(poolingMode == 2) { // pnorm
@ -1466,7 +1466,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
} }
else { else {
nd4j_printf("ConvolutionUtils::pooling2dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); nd4j_printf("ConvolutionUtils::pooling2dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
@ -1618,7 +1618,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
} }
/*************************************************************************/ /*************************************************************************/
else if(poolingMode == 1) { // avg else if(poolingMode == 1) { // avg
@ -1679,7 +1679,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
} }
/*************************************************************************/ /*************************************************************************/
else if(poolingMode == 2) { // pnorm else if(poolingMode == 2) { // pnorm
@ -1761,7 +1761,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
} }
else { else {
nd4j_printf("ConvolutionUtils::pooling3dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); nd4j_printf("ConvolutionUtils::pooling3dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);

View File

@ -115,7 +115,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_for(func, 0, cropHeight); samediff::Threads::parallel_for(func, 0, cropHeight);
} }
} }
} }

View File

@ -48,7 +48,7 @@ void crossBatched(sd::LaunchContext * context, NDArray *a, NDArray *b, NDArray *
} }
}; };
sd::Threads::parallel_tad(func, 0, tads); samediff::Threads::parallel_tad(func, 0, tads);
} }
} }

View File

@ -65,7 +65,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_for(func, 0, total_count); samediff::Threads::parallel_for(func, 0, total_count);
} else { } else {
const int total_count = batch_size * input_depth_by_input_area; const int total_count = batch_size * input_depth_by_input_area;
@ -89,7 +89,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_for(func, 0, total_count); samediff::Threads::parallel_for(func, 0, total_count);
} }
} }

View File

@ -35,7 +35,7 @@ static void diGamma_(const NDArray& x, NDArray& z) {
for (auto i = start; i < stop; i++) for (auto i = start; i < stop; i++)
z.p(i, diGammaScalar<T>(x.e<T>(i))); z.p(i, diGammaScalar<T>(x.e<T>(i)));
}; };
sd::Threads::parallel_for(func, 0, x.lengthOf()); samediff::Threads::parallel_for(func, 0, x.lengthOf());
} }
void diGamma(sd::LaunchContext* context, const NDArray& x, NDArray& z) { void diGamma(sd::LaunchContext* context, const NDArray& x, NDArray& z) {

View File

@ -87,7 +87,7 @@ static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, const
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1);
} }
void dilation2d(sd::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) { void dilation2d(sd::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {

View File

@ -43,7 +43,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_for(func, 0, inLen); samediff::Threads::parallel_for(func, 0, inLen);
} }
BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES);
@ -137,7 +137,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_for(func, 0, input->lengthOf()); samediff::Threads::parallel_for(func, 0, input->lengthOf());
return Status::OK(); return Status::OK();
} }

View File

@ -71,7 +71,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_tad(func, 0, outSize); samediff::Threads::parallel_tad(func, 0, outSize);
} }
} }
template <typename T> template <typename T>
@ -177,7 +177,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_tad(func, 0, gradsSize); samediff::Threads::parallel_tad(func, 0, gradsSize);
} }
outputList[1]->assign(indices); outputList[1]->assign(indices);

View File

@ -82,7 +82,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_tad(func, 0, batchCount); samediff::Threads::parallel_tad(func, 0, batchCount);
} }

View File

@ -63,7 +63,7 @@ void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* in
output->p(i, input->e(indices->e<Nd4jLong>(i))); output->p(i, input->e(indices->e<Nd4jLong>(i)));
}; };
sd::Threads::parallel_for(func, 0, output->lengthOf()); samediff::Threads::parallel_for(func, 0, output->lengthOf());
} }
else { else {
@ -96,7 +96,7 @@ void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* in
memcpy(outBuff, inBuff, shape::length(inTadShapeInfo) * input->sizeOfT()); memcpy(outBuff, inBuff, shape::length(inTadShapeInfo) * input->sizeOfT());
} }
}; };
sd::Threads::parallel_tad(func, 0, numOfSubArrs); samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
} }
else { else {
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
@ -112,7 +112,7 @@ void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* in
} }
}; };
sd::Threads::parallel_tad(func, 0, numOfSubArrs); samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
} }
} }
} }
@ -148,7 +148,7 @@ void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* in
std::memcpy(outBuff, inBuff, shape::length(inTadShapeInfo) * input->sizeOfT()); std::memcpy(outBuff, inBuff, shape::length(inTadShapeInfo) * input->sizeOfT());
} }
}; };
sd::Threads::parallel_tad(func, 0, numOfSubArrs); samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
} }
else { else {
@ -167,7 +167,7 @@ void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* in
} }
}; };
sd::Threads::parallel_tad(func, 0, numOfSubArrs); samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
} }
} }

View File

@ -64,7 +64,7 @@ namespace sd {
} }
}; };
maxThreads = sd::Threads::parallel_for(func, 0, lengthOf); maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf);
} else if (xEws > 1 && yEws > 1 && x.ordering() == y.ordering()) { } else if (xEws > 1 && yEws > 1 && x.ordering() == y.ordering()) {
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e++) { for (auto e = start; e < stop; e++) {
@ -75,7 +75,7 @@ namespace sd {
} }
}; };
maxThreads = sd::Threads::parallel_for(func, 0, lengthOf); maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf);
} else { } else {
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e++) { for (auto e = start; e < stop; e++) {
@ -86,7 +86,7 @@ namespace sd {
} }
}; };
maxThreads = sd::Threads::parallel_for(func, 0, lengthOf); maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf);
} }
// accumulate intermediate variables into output array // accumulate intermediate variables into output array

View File

@ -54,7 +54,7 @@ namespace sd {
tempBuffer[b] = r; tempBuffer[b] = r;
} }
}; };
sd::Threads::parallel_tad(func, 0, numBlocks); samediff::Threads::parallel_tad(func, 0, numBlocks);
// we replace pointer with intermediate one, and repeat only one chunk left // we replace pointer with intermediate one, and repeat only one chunk left
int iterationCount = 0; int iterationCount = 0;
@ -76,7 +76,7 @@ namespace sd {
tempResult[b] = r; tempResult[b] = r;
} }
}; };
sd::Threads::parallel_tad(func2, 0, numBlocks); samediff::Threads::parallel_tad(func2, 0, numBlocks);
iterationCount++; iterationCount++;

View File

@ -90,7 +90,7 @@ static void im2col_(sd::LaunchContext & context, const NDArray& input, NDArray&
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
} }
else { else {
@ -124,7 +124,7 @@ static void im2col_(sd::LaunchContext & context, const NDArray& input, NDArray&
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1);
} }
} }

View File

@ -149,7 +149,7 @@ namespace helpers {
} }
} }
}; };
sd::Threads::parallel_tad(func, 0, batchSize); samediff::Threads::parallel_tad(func, 0, batchSize);
} }

View File

@ -178,7 +178,7 @@ namespace helpers {
interpolationData[i]._interpolarValue = in - in_f; interpolationData[i]._interpolarValue = in - in_f;
} }
}; };
sd::Threads::parallel_for(func, 0, outSize); samediff::Threads::parallel_for(func, 0, outSize);
} }
/** /**
@ -240,7 +240,7 @@ namespace helpers {
} }
} }
}; };
sd::Threads::parallel_tad(func, 0, batchSize); samediff::Threads::parallel_tad(func, 0, batchSize);
} }
template<typename X, typename Z> template<typename X, typename Z>
@ -285,7 +285,7 @@ namespace helpers {
xs[i]._topIndex *= channels; xs[i]._topIndex *= channels;
} }
}; };
sd::Threads::parallel_for(func, 0, xsSize); samediff::Threads::parallel_for(func, 0, xsSize);
resizeImage_<X,Z>(images->getDataBuffer()->primaryAsT<X>(), batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output->dataBuffer()->primaryAsT<Z>()); resizeImage_<X,Z>(images->getDataBuffer()->primaryAsT<X>(), batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output->dataBuffer()->primaryAsT<Z>());
return Status::OK(); return Status::OK();
@ -323,7 +323,7 @@ namespace helpers {
} }
} }
}; };
sd::Threads::parallel_for(func, 0, batchSize, 1, 0, outHeight, 1); samediff::Threads::parallel_for(func, 0, batchSize, 1, 0, outHeight, 1);
} }
template<typename T> template<typename T>
@ -427,7 +427,7 @@ namespace helpers {
coeffs_table[i * 2 + 1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; coeffs_table[i * 2 + 1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
} }
}; };
sd::Threads::parallel_for(func, 0, kTableSize); samediff::Threads::parallel_for(func, 0, kTableSize);
return coeffs_table; return coeffs_table;
} }
@ -541,7 +541,7 @@ namespace helpers {
x_wai._index3); x_wai._index3);
} }
}; };
sd::Threads::parallel_for(func, 0, resizer_state.outWidth); samediff::Threads::parallel_for(func, 0, resizer_state.outWidth);
} else { } else {
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto x = start; x < stop; ++x) { for (auto x = start; x < stop; ++x) {
@ -552,7 +552,7 @@ namespace helpers {
x_wai._index3); x_wai._index3);
} }
}; };
sd::Threads::parallel_for(func, 0, resizer_state.outWidth); samediff::Threads::parallel_for(func, 0, resizer_state.outWidth);
} }
// Scale the values so they can be used as offsets into buffers. // Scale the values so they can be used as offsets into buffers.
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
@ -563,7 +563,7 @@ namespace helpers {
(*x_wais)[x]._index3 *= resizer_state.channels; (*x_wais)[x]._index3 *= resizer_state.channels;
} }
}; };
sd::Threads::parallel_for(func, 0, resizer_state.outWidth); samediff::Threads::parallel_for(func, 0, resizer_state.outWidth);
} }
template <typename T> template <typename T>
@ -774,7 +774,7 @@ namespace helpers {
} }
} }
}; };
sd::Threads::parallel_tad(func, 0, batchNum); samediff::Threads::parallel_tad(func, 0, batchNum);
} }
// simplified bicubic resize without antialiasing // simplified bicubic resize without antialiasing
@ -950,7 +950,7 @@ namespace helpers {
} }
} }
}; };
sd::Threads::parallel_tad(batchProcess, 0, st.batchSize, 1); samediff::Threads::parallel_tad(batchProcess, 0, st.batchSize, 1);
} }
template <typename X> template <typename X>
@ -981,7 +981,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_for(cachingProcedure, 0, xCached.size(), 1); samediff::Threads::parallel_for(cachingProcedure, 0, xCached.size(), 1);
resizeArea<X>(st, xCached, image, output); resizeArea<X>(st, xCached, image, output);
} }

View File

@ -45,7 +45,7 @@ static void rgbToGrs_(const NDArray& input, NDArray& output, const int dimC) {
} }
}; };
sd::Threads::parallel_for(func, 0, output.lengthOf(), 1); samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
return; return;
} }
@ -62,7 +62,7 @@ static void rgbToGrs_(const NDArray& input, NDArray& output, const int dimC) {
} }
}; };
sd::Threads::parallel_for(func, 0, output.lengthOf(), 1); samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
return; return;
} }
@ -87,7 +87,7 @@ FORCEINLINE static void rgbToFromYuv_(const NDArray& input, NDArray& output, con
} }
}; };
sd::Threads::parallel_for(func, 0, input.lengthOf(), 3); samediff::Threads::parallel_for(func, 0, input.lengthOf(), 3);
return; return;
} }
@ -106,7 +106,7 @@ FORCEINLINE static void rgbToFromYuv_(const NDArray& input, NDArray& output, con
} }
}; };
sd::Threads::parallel_tad(func, 0, numOfTads); samediff::Threads::parallel_tad(func, 0, numOfTads);
return; return;
} }
@ -146,7 +146,7 @@ FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output,
} }
}; };
sd::Threads::parallel_for(func, 0, input->lengthOf(), 3); samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3);
} }
else { else {
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC); auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
@ -165,7 +165,7 @@ FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output,
} }
}; };
sd::Threads::parallel_tad(func, 0, numOfTads); samediff::Threads::parallel_tad(func, 0, numOfTads);
} }
} }
@ -196,7 +196,7 @@ FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output,
} }
}; };
sd::Threads::parallel_for(func, 0, input->lengthOf(), 3); samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3);
} }
else { else {
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC); auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
@ -222,7 +222,7 @@ FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output,
} }
}; };
sd::Threads::parallel_tad(func, 0, numOfTads); samediff::Threads::parallel_tad(func, 0, numOfTads);
} }
} }

View File

@ -195,7 +195,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
} }
}; };
sd::Threads::parallel_tad(func, 0, tads); samediff::Threads::parallel_tad(func, 0, tads);
} }
} }

View File

@ -96,7 +96,7 @@ static int lrnFunctor_(sd::graph::Context& block, NDArray* input, NDArray* outpu
} }
}; };
sd::Threads::parallel_tad(func, 0, numOfTads); samediff::Threads::parallel_tad(func, 0, numOfTads);
} }
else { else {
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
@ -134,7 +134,7 @@ static int lrnFunctor_(sd::graph::Context& block, NDArray* input, NDArray* outpu
} }
}; };
sd::Threads::parallel_tad(func, 0, numOfTads); samediff::Threads::parallel_tad(func, 0, numOfTads);
} }
return Status::OK(); return Status::OK();
} }
@ -242,7 +242,7 @@ static void lrnBP_(const NDArray& input, const NDArray& gradO, NDArray& gradI, c
} }
}; };
sd::Threads::parallel_tad(func, 0, numOfTads); samediff::Threads::parallel_tad(func, 0, numOfTads);
} }
else { else {
@ -317,7 +317,7 @@ static void lrnBP_(const NDArray& input, const NDArray& gradO, NDArray& gradI, c
} }
}; };
sd::Threads::parallel_tad(func, 0, numOfTads); samediff::Threads::parallel_tad(func, 0, numOfTads);
} }
gradI *= gradO; gradI *= gradO;
} }

View File

@ -130,7 +130,7 @@ static void fusedTanh(NDArray *z, NDArray *i, NDArray *c, const NDArray *cLast,
} }
}; };
sd::Threads::parallel_for(func, 0, uLen); samediff::Threads::parallel_for(func, 0, uLen);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////

View File

@ -54,7 +54,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_tad(loop, 0, n, 1); samediff::Threads::parallel_tad(loop, 0, n, 1);
} }
} }
@ -79,8 +79,8 @@ namespace helpers {
invertedMatrix->t<T>(i, i - 1) -= (inputMatrix->t<T>(i, i - 1) * invertedMatrix->t<T>(i - 1, i - 1) / inputMatrix->t<T>(i, i)); invertedMatrix->t<T>(i, i - 1) -= (inputMatrix->t<T>(i, i - 1) * invertedMatrix->t<T>(i - 1, i - 1) / inputMatrix->t<T>(i, i));
}; };
sd::Threads::parallel_for(invertDiagonals, 0, n, 1); samediff::Threads::parallel_for(invertDiagonals, 0, n, 1);
sd::Threads::parallel_for(invertSubDiagonals, 1, n, 1); samediff::Threads::parallel_for(invertSubDiagonals, 1, n, 1);
// PRAGMA_OMP_PARALLEL_FOR_SIMD // PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int i = 1; i < n; i++) { for (int i = 1; i < n; i++) {
@ -118,8 +118,8 @@ namespace helpers {
inputMatrix->t<T>(i, i)); inputMatrix->t<T>(i, i));
}; };
sd::Threads::parallel_for(invertDiagonals, 0, n, 1); samediff::Threads::parallel_for(invertDiagonals, 0, n, 1);
sd::Threads::parallel_for(invertUpDiagonals, 0, n - 1, 1); samediff::Threads::parallel_for(invertUpDiagonals, 0, n - 1, 1);
// PRAGMA_OMP_PARALLEL_FOR_SIMD // PRAGMA_OMP_PARALLEL_FOR_SIMD
for (auto i = n - 2; i >= 0; i--) { for (auto i = n - 2; i >= 0; i--) {
@ -225,7 +225,7 @@ namespace helpers {
} }
} }
//}; //};
//sd::Threads::parallel_for(loop, column, rowNum, 1); //samediff::Threads::parallel_for(loop, column, rowNum, 1);
return result; return result;
} }
@ -247,7 +247,7 @@ namespace helpers {
} }
} }
}; };
sd::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1); samediff::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1);
} }
template <typename T> template <typename T>
@ -327,7 +327,7 @@ namespace helpers {
luNN_<T, I>(context, outputs.at(i), permutationVectors?permutations.at(i):nullptr, n); luNN_<T, I>(context, outputs.at(i), permutationVectors?permutations.at(i):nullptr, n);
} }
}; };
sd::Threads::parallel_for(loop, 0, outputs.size(), 1); samediff::Threads::parallel_for(loop, 0, outputs.size(), 1);
} }
void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) { void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) {

View File

@ -63,7 +63,7 @@ void matrixSetDiag_(const NDArray& input, const NDArray& diagonal, NDArray& outp
z[zOffset] = zeroPad ? static_cast<T>(0) : x[xOffset]; z[zOffset] = zeroPad ? static_cast<T>(0) : x[xOffset];
} }
}; };
sd::Threads::parallel_for(func, 0, xLen); samediff::Threads::parallel_for(func, 0, xLen);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////

View File

@ -51,7 +51,7 @@ int _matrixDiagPart(const NDArray* input, NDArray* output) {
listOut.at(i)->p(j, listDiag.at(i)->e<T>(j, j)); listOut.at(i)->p(j, listDiag.at(i)->e<T>(j, j));
}; };
sd::Threads::parallel_tad(func, 0, lO); samediff::Threads::parallel_tad(func, 0, lO);
return Status::OK(); return Status::OK();
} }

View File

@ -61,7 +61,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_for(func, 0, oL); samediff::Threads::parallel_for(func, 0, oL);
} }
} }

View File

@ -67,7 +67,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_tad(func, 0, numTads); samediff::Threads::parallel_tad(func, 0, numTads);
} else { } else {
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e++) { for (auto e = start; e < stop; e++) {
@ -88,7 +88,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_tad(func, 0, numTads); samediff::Threads::parallel_tad(func, 0, numTads);
} }
} }

View File

@ -80,7 +80,7 @@ static void polyGamma_(sd::LaunchContext * context, const NDArray& n, const NDAr
output.p(i, polyGammaScalar<T>(context, order, x.e<T>(i))); output.p(i, polyGammaScalar<T>(context, order, x.e<T>(i)));
} }
}; };
sd::Threads::parallel_for(func, 0, x.lengthOf()); samediff::Threads::parallel_for(func, 0, x.lengthOf());
} }
void polyGamma(sd::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) { void polyGamma(sd::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) {

View File

@ -48,7 +48,7 @@ namespace helpers {
resBuf[i * n + j] = -2 * vBuf[i] * vBuf[j] + (i == j ? T(1) : T(0)); resBuf[i * n + j] = -2 * vBuf[i] * vBuf[j] + (i == j ? T(1) : T(0));
}; };
sd::Threads::parallel_for(interloop, 0, n, 1, 0, n, 1); samediff::Threads::parallel_for(interloop, 0, n, 1, 0, n, 1);
return res; return res;
} }
@ -119,7 +119,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_tad(batching, 0, listOutQ.size(), 1); samediff::Threads::parallel_tad(batching, 0, listOutQ.size(), 1);
} }

View File

@ -197,7 +197,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_for(func, 0, batchValue, 1, 0, numOfSamples, 1); samediff::Threads::parallel_for(func, 0, batchValue, 1, 0, numOfSamples, 1);
rng.rewindH(output.lengthOf()*numOfClassX); rng.rewindH(output.lengthOf()*numOfClassX);
return; return;

View File

@ -42,7 +42,7 @@ static void _range(const NDArray& start, const NDArray& delta, NDArray& outVecto
for (auto i = start; i < stop; i++) for (auto i = start; i < stop; i++)
buff[i] = s + i * d; buff[i] = s + i * d;
}; };
sd::Threads::parallel_for(func, 0, len); samediff::Threads::parallel_for(func, 0, len);
} }
void range(sd::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector) { void range(sd::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector) {

View File

@ -59,7 +59,7 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
swap(inArr, e, idx); swap(inArr, e, idx);
} }
}; };
sd::Threads::parallel_for(func, 0, numOfElemsToReverse / 2); samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2);
} }
else if (inEWS > 1) { else if (inEWS > 1) {
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
@ -70,7 +70,7 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
} }
}; };
sd::Threads::parallel_for(func, 0, numOfElemsToReverse / 2); samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2);
} }
else { else {
@ -82,7 +82,7 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
} }
}; };
sd::Threads::parallel_for(func, 0, numOfElemsToReverse / 2); samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2);
} }
} }
else { else {
@ -96,14 +96,14 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
for (Nd4jLong e = start; e < stop; e++) for (Nd4jLong e = start; e < stop; e++)
outArr[sLength - e] = inArr[e]; outArr[sLength - e] = inArr[e];
}; };
sd::Threads::parallel_for(func, 0, numOfElemsToReverse); samediff::Threads::parallel_for(func, 0, numOfElemsToReverse);
if(inLength != numOfElemsToReverse) { if(inLength != numOfElemsToReverse) {
auto f2 = PRAGMA_THREADS_FOR { auto f2 = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e++) for (auto e = start; e < stop; e++)
outArr[e] = inArr[e]; outArr[e] = inArr[e];
}; };
sd::Threads::parallel_for(f2, numOfElemsToReverse, inLength); samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength);
} }
} }
else if (inEWS >= 1 && outEWS >= 1 && inOrder == outOrder) { else if (inEWS >= 1 && outEWS >= 1 && inOrder == outOrder) {
@ -112,14 +112,14 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
for (auto e = start; e < stop; e++) for (auto e = start; e < stop; e++)
outArr[(sLength - e) * outEWS] = inArr[e * inEWS]; outArr[(sLength - e) * outEWS] = inArr[e * inEWS];
}; };
sd::Threads::parallel_for(func, 0, numOfElemsToReverse); samediff::Threads::parallel_for(func, 0, numOfElemsToReverse);
if(inLength != numOfElemsToReverse) { if(inLength != numOfElemsToReverse) {
auto f2 = PRAGMA_THREADS_FOR { auto f2 = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e++) for (auto e = start; e < stop; e++)
outArr[e * outEWS] = inArr[e * inEWS]; outArr[e * outEWS] = inArr[e * inEWS];
}; };
sd::Threads::parallel_for(f2, numOfElemsToReverse, inLength); samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength);
} }
} }
else { else {
@ -131,7 +131,7 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
outArr[outOffset] = inArr[inOffset]; outArr[outOffset] = inArr[inOffset];
} }
}; };
sd::Threads::parallel_for(func, 0, numOfElemsToReverse); samediff::Threads::parallel_for(func, 0, numOfElemsToReverse);
if(inLength != numOfElemsToReverse) { if(inLength != numOfElemsToReverse) {
@ -142,7 +142,7 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
outArr[outOffset] = inArr[inOffset]; outArr[outOffset] = inArr[inOffset];
} }
}; };
sd::Threads::parallel_for(f2, numOfElemsToReverse, inLength); samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength);
} }
} }
} }

View File

@ -69,7 +69,7 @@ static void batchToSpace_(const NDArray& input, NDArray& output, const uint crop
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, cropBottom, iH - cropTop, 1, cropLeft, iW - cropRight, 1); samediff::Threads::parallel_for(func, 0, bS, 1, cropBottom, iH - cropTop, 1, cropLeft, iW - cropRight, 1);
} }
BUILD_SINGLE_TEMPLATE(template void batchToSpace_, (const NDArray& input, NDArray& output, const uint cropBottom, const uint cropTop, const uint cropLeft, const uint cropRight), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void batchToSpace_, (const NDArray& input, NDArray& output, const uint cropBottom, const uint cropTop, const uint cropLeft, const uint cropRight), LIBND4J_TYPES);
@ -128,7 +128,7 @@ static void batchToSpaceND_(const NDArray& input, const NDArray& crop, NDArray&
} }
}; };
sd::Threads::parallel_tad(func, 0, zLen); samediff::Threads::parallel_tad(func, 0, zLen);
} }
BUILD_SINGLE_TEMPLATE(template void batchToSpaceND_, (const NDArray& input, const NDArray& crop, NDArray& output, const uint numOfSpatialDims), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void batchToSpaceND_, (const NDArray& input, const NDArray& crop, NDArray& output, const uint numOfSpatialDims), LIBND4J_TYPES);
@ -234,7 +234,7 @@ static void spaceToBatch_(const NDArray& input, NDArray& output, const uint padB
} }
}; };
sd::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1); samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1);
} }
BUILD_SINGLE_TEMPLATE(template void spaceToBatch_, (const NDArray& input, NDArray& output, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void spaceToBatch_, (const NDArray& input, NDArray& output, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight), LIBND4J_TYPES);
@ -327,7 +327,7 @@ static void spaceToBatchND_(const NDArray& input, const NDArray& padding, NDArra
} }
}; };
sd::Threads::parallel_tad(func, 0, zLen); samediff::Threads::parallel_tad(func, 0, zLen);
} }
BUILD_SINGLE_TEMPLATE(template void spaceToBatchND_, (const NDArray& input, const NDArray& padding, NDArray& output, const uint numOfSpatialDims), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void spaceToBatchND_, (const NDArray& input, const NDArray& padding, NDArray& output, const uint numOfSpatialDims), LIBND4J_TYPES);

View File

@ -69,7 +69,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_for(func, 0, total_count); samediff::Threads::parallel_for(func, 0, total_count);
} else { } else {
const int total_count = batch_size * output_depth_by_output_area; const int total_count = batch_size * output_depth_by_output_area;
@ -93,7 +93,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_for(func, 0, total_count); samediff::Threads::parallel_for(func, 0, total_count);
} }
} }

View File

@ -58,7 +58,7 @@ Nd4jLong checkIndices_(const NDArray& indices, const NDArray& output, const int
} }
}; };
sd::Threads::parallel_for(func, 0, indices.lengthOf()); samediff::Threads::parallel_for(func, 0, indices.lengthOf());
return numOfBadIndx; return numOfBadIndx;
} }
@ -87,7 +87,7 @@ void scatter(sd::LaunchContext *context, pairwise::Ops op, const NDArray& indic
} }
}; };
sd::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads()); samediff::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads());
} }
else { // outRank > 1 else { // outRank > 1
@ -107,7 +107,7 @@ void scatter(sd::LaunchContext *context, pairwise::Ops op, const NDArray& indic
} }
}; };
sd::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads()); samediff::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads());
} }
} }
@ -129,7 +129,7 @@ void scatterND(sd::LaunchContext *context, pairwise::Ops op, const NDArray& ind
} }
}; };
sd::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads()); samediff::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads());
} }
else { else {
std::vector<int> dimsToExcludeInd = ShapeUtils::evalDimsToExclude(indRank, {indRank-1}); std::vector<int> dimsToExcludeInd = ShapeUtils::evalDimsToExclude(indRank, {indRank-1});
@ -154,7 +154,7 @@ void scatterND(sd::LaunchContext *context, pairwise::Ops op, const NDArray& ind
} }
}; };
sd::Threads::parallel_tad(func, 0, indLen / indLastDim, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads()); samediff::Threads::parallel_tad(func, 0, indLen / indLastDim, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads());
} }
} }
@ -176,7 +176,7 @@ void scatterForLoss(sd::LaunchContext *context, const NDArray& indices, NDArray
} }
}; };
sd::Threads::parallel_for(func, 0, indicesLen); samediff::Threads::parallel_for(func, 0, indicesLen);
} else { } else {
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i++) { for (auto i = start; i < stop; i++) {
@ -186,7 +186,7 @@ void scatterForLoss(sd::LaunchContext *context, const NDArray& indices, NDArray
} }
}; };
sd::Threads::parallel_for(func, 0, indicesLen); samediff::Threads::parallel_for(func, 0, indicesLen);
} }
} }

View File

@ -173,7 +173,7 @@ namespace helpers {
meanV.p<T>(e, meanV.e<T>(e) + listOfTensors.at(i)->e<T>(e)); meanV.p<T>(e, meanV.e<T>(e) + listOfTensors.at(i)->e<T>(e));
} }
}; };
sd::Threads::parallel_for(func, 0, meanT->lengthOf()); samediff::Threads::parallel_for(func, 0, meanT->lengthOf());
count++; count++;
} }
@ -227,7 +227,7 @@ namespace helpers {
sumT->p(e, sumT->e<T>(e) + listOfTensors.at(i)->e<T>(e)); sumT->p(e, sumT->e<T>(e) + listOfTensors.at(i)->e<T>(e));
} }
}; };
sd::Threads::parallel_for(func, 0, sumT->lengthOf()); samediff::Threads::parallel_for(func, 0, sumT->lengthOf());
} }
else { else {
idx = indices->e<int>(i); idx = indices->e<int>(i);
@ -276,7 +276,7 @@ namespace helpers {
sumT->p(e, sumT->e<T>(e) * listOfTensors.at(i)->e<T>(e)); sumT->p(e, sumT->e<T>(e) * listOfTensors.at(i)->e<T>(e));
} }
}; };
sd::Threads::parallel_for(func, 0, sumT->lengthOf()); samediff::Threads::parallel_for(func, 0, sumT->lengthOf());
} }
else { else {
idx = indices->e<int>(i); idx = indices->e<int>(i);
@ -631,7 +631,7 @@ namespace helpers {
output->p(e, gradOut->e<T>(classNum)); output->p(e, gradOut->e<T>(classNum));
} }
}; };
sd::Threads::parallel_for(func, 0, loop_size); samediff::Threads::parallel_for(func, 0, loop_size);
} }
else { else {
std::vector<int> restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); std::vector<int> restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
@ -658,7 +658,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_tad(func, 0, indices->lengthOf()); samediff::Threads::parallel_tad(func, 0, indices->lengthOf());
} }
return ND4J_STATUS_OK; return ND4J_STATUS_OK;
@ -681,7 +681,7 @@ namespace helpers {
output->p(e, gradOut->e<double>(classNum)); output->p(e, gradOut->e<double>(classNum));
} }
}; };
sd::Threads::parallel_for(func, 0, input->lengthOf()); samediff::Threads::parallel_for(func, 0, input->lengthOf());
} }
else { else {
auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
@ -711,7 +711,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_tad(func, 0, indices->lengthOf()); samediff::Threads::parallel_tad(func, 0, indices->lengthOf());
} }
return ND4J_STATUS_OK; return ND4J_STATUS_OK;
} }
@ -758,7 +758,7 @@ namespace helpers {
} }
//}; //};
//sd::Threads::parallel_for(func, 0, indices->lengthOf()); //samediff::Threads::parallel_for(func, 0, indices->lengthOf());
} }
return ND4J_STATUS_OK; return ND4J_STATUS_OK;
} }
@ -791,7 +791,7 @@ namespace helpers {
} }
//}; //};
//sd::Threads::parallel_for(func, 0, indices->lengthOf()); //samediff::Threads::parallel_for(func, 0, indices->lengthOf());
} }
return Status::OK(); return Status::OK();
} }
@ -828,7 +828,7 @@ namespace helpers {
} }
//}; //};
//sd::Threads::parallel_for(func, 0, indices->lengthOf()); //samediff::Threads::parallel_for(func, 0, indices->lengthOf());
} }
return ND4J_STATUS_OK; return ND4J_STATUS_OK;
@ -894,7 +894,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_for(func, 0, input->lengthOf()); samediff::Threads::parallel_for(func, 0, input->lengthOf());
} }
else { else {
auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
@ -918,7 +918,7 @@ namespace helpers {
} }
//}; //};
//sd::Threads::parallel_for(func, 0, indices->lengthOf()); //samediff::Threads::parallel_for(func, 0, indices->lengthOf());
} }
return ND4J_STATUS_OK; return ND4J_STATUS_OK;
@ -993,7 +993,7 @@ namespace helpers {
} }
//}; //};
//sd::Threads::parallel_for(func, 0, indices->lengthOf()); //samediff::Threads::parallel_for(func, 0, indices->lengthOf());
} }
return Status::OK(); return Status::OK();
} }
@ -1010,7 +1010,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_for(func, 0, indices->lengthOf()); samediff::Threads::parallel_for(func, 0, indices->lengthOf());
} }
else { else {
auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
@ -1032,7 +1032,7 @@ namespace helpers {
} }
//}; //};
//sd::Threads::parallel_for(func, 0, indices->lengthOf()); //samediff::Threads::parallel_for(func, 0, indices->lengthOf());
} }
return Status::OK(); return Status::OK();
@ -1059,7 +1059,7 @@ namespace helpers {
} }
//}; //};
//sd::Threads::parallel_for(func, 0, indices->lengthOf()); //samediff::Threads::parallel_for(func, 0, indices->lengthOf());
} }
else { else {
auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
@ -1081,7 +1081,7 @@ namespace helpers {
} }
//}; //};
//sd::Threads::parallel_for(func, 0, indices->lengthOf()); //samediff::Threads::parallel_for(func, 0, indices->lengthOf());
} }
return Status::OK(); return Status::OK();
} }

View File

@ -34,7 +34,7 @@ namespace helpers {
output->t<B>(k * maxIndex + i) = B(true); //, T(1.0f)); output->t<B>(k * maxIndex + i) = B(true); //, T(1.0f));
}; };
sd::Threads::parallel_for(func, 0, maxIndex, 1, 0, input->lengthOf(), 1); samediff::Threads::parallel_for(func, 0, maxIndex, 1, 0, input->lengthOf(), 1);
} }
void sequenceMask(sd::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) { void sequenceMask(sd::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {

View File

@ -425,7 +425,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_tad(func, 0, numTargets, 1, numThreads); samediff::Threads::parallel_tad(func, 0, numTargets, 1, numThreads);
} }
BUILD_SINGLE_TEMPLATE(template void skipgramBatchExec_, (NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool preciseMode, const int numThreads), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template void skipgramBatchExec_, (NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool preciseMode, const int numThreads), FLOAT_TYPES);
@ -577,7 +577,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_tad(func, 0, numTargets, 1, numThreads); samediff::Threads::parallel_tad(func, 0, numTargets, 1, numThreads);
} }
BUILD_SINGLE_TEMPLATE(template void cbowBatchExec_, (NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &context, NDArray &lockedWords, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, NDArray &nLabels, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool trainWords, const int numThreads), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template void cbowBatchExec_, (NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &context, NDArray &lockedWords, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, NDArray &nLabels, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool trainWords, const int numThreads), FLOAT_TYPES);

View File

@ -136,7 +136,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_tad(func,0, numOfSubArrs); samediff::Threads::parallel_tad(func,0, numOfSubArrs);
} }
#endif #endif
@ -168,7 +168,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_tad(func,0, numOfSubArrs); samediff::Threads::parallel_tad(func,0, numOfSubArrs);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -228,7 +228,7 @@ namespace sd {
} }
}; };
sd::Threads::parallel_tad(func, 0, numOfSubArrs); samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
delete []offsets; delete []offsets;
} }

View File

@ -48,7 +48,7 @@ namespace helpers {
} }
} }
}; };
sd::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1); samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1);
} }
// --------------------------------------------------------------------------------------------------------------------------------------- // // --------------------------------------------------------------------------------------------------------------------------------------- //

View File

@ -115,7 +115,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_for(func, 0, input.lengthOf()); samediff::Threads::parallel_for(func, 0, input.lengthOf());
} }
void split(sd::LaunchContext* context, const NDArray& input, std::vector<NDArray*>& outArrs, const int axis) { void split(sd::LaunchContext* context, const NDArray& input, std::vector<NDArray*>& outArrs, const int axis) {

View File

@ -184,7 +184,7 @@ static void sruBI_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray
} }
}; };
sd::Threads::parallel_tad(func, 0, ncols); samediff::Threads::parallel_tad(func, 0, ncols);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -303,7 +303,7 @@ static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArr
} }
}; };
sd::Threads::parallel_tad(func, 0, ncols); samediff::Threads::parallel_tad(func, 0, ncols);
// gradB // gradB
gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0}); // [4*K] gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0}); // [4*K]

View File

@ -43,7 +43,7 @@ static void stack_(const std::vector<const NDArray*>& inArrs, NDArray& output, c
output.p<T>(i, inArrs[i]->t<T>(0)); output.p<T>(i, inArrs[i]->t<T>(0));
}; };
sd::Threads::parallel_for(func, 0, numOfSubArrs); samediff::Threads::parallel_for(func, 0, numOfSubArrs);
} }
else { else {
@ -63,7 +63,7 @@ static void stack_(const std::vector<const NDArray*>& inArrs, NDArray& output, c
} }
}; };
sd::Threads::parallel_tad(func, 0, numOfSubArrs); samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
} }
} }
@ -88,7 +88,7 @@ static void unstack_(const NDArray& input, const std::vector<NDArray*>& outArrs,
outArrs[i]->p<T>(0, input.t<T>(i)); outArrs[i]->p<T>(0, input.t<T>(i));
}; };
sd::Threads::parallel_for(func, 0, numOfSubArrs); samediff::Threads::parallel_for(func, 0, numOfSubArrs);
} }
else { else {
@ -107,7 +107,7 @@ static void unstack_(const NDArray& input, const std::vector<NDArray*>& outArrs,
} }
}; };
sd::Threads::parallel_tad(func, 0, numOfSubArrs); samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
} }
} }

View File

@ -163,7 +163,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_tad(func, 0, target->lengthOf()); samediff::Threads::parallel_tad(func, 0, target->lengthOf());
} }
return status; return status;

View File

@ -48,7 +48,7 @@ static void triuBP_(sd::LaunchContext * context, const NDArray& input, const NDA
dOdI.t<T>(i) = static_cast<T>(1.f); dOdI.t<T>(i) = static_cast<T>(1.f);
} }
}; };
sd::Threads::parallel_for(func, 0, dLen); samediff::Threads::parallel_for(func, 0, dLen);
// FIXME: !!! // FIXME: !!!
gradI.assign(dOdI * gradO); // chain rule: dLoss/dI = dO/dI * dLoss/dO gradI.assign(dOdI * gradO); // chain rule: dLoss/dI = dO/dI * dLoss/dO
@ -68,7 +68,7 @@ static void trace_(const NDArray& input, NDArray& output) {
for (auto i = start; i < stop; i++) for (auto i = start; i < stop; i++)
output.p(i, setOfSubArrs.at(i)->getTrace()); output.p(i, setOfSubArrs.at(i)->getTrace());
}; };
sd::Threads::parallel_for(func, 0, setOfSubArrs.size()); samediff::Threads::parallel_for(func, 0, setOfSubArrs.size());
} }
void trace(sd::LaunchContext * context, const NDArray& input, NDArray& output) { void trace(sd::LaunchContext * context, const NDArray& input, NDArray& output) {
@ -211,7 +211,7 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray
} }
}; };
sd::Threads::parallel_tad(func, 0, zLen); samediff::Threads::parallel_tad(func, 0, zLen);
} }
else { // REFLECT and SYMMETRIC cases else { // REFLECT and SYMMETRIC cases
@ -237,7 +237,7 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray
} }
}; };
sd::Threads::parallel_tad(func, 0, zLen); samediff::Threads::parallel_tad(func, 0, zLen);
} }
} }
@ -606,7 +606,7 @@ static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) {
} }
}; };
sd::Threads::parallel_tad(func, 0, zLen); samediff::Threads::parallel_tad(func, 0, zLen);
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -654,7 +654,7 @@ static void gather_(NDArray* input, const NDArray* indices, NDArray* output, con
output->p(e, input->e<T>(indices->e<Nd4jLong>(e))); output->p(e, input->e<T>(indices->e<Nd4jLong>(e)));
}; };
sd::Threads::parallel_for(func, 0, indices->lengthOf()); samediff::Threads::parallel_for(func, 0, indices->lengthOf());
} }
else { else {
@ -670,7 +670,7 @@ static void gather_(NDArray* input, const NDArray* indices, NDArray* output, con
} }
}; };
sd::Threads::parallel_tad(func, 0, numOfSubArrs); samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
} }
} }
else { else {
@ -694,7 +694,7 @@ static void gather_(NDArray* input, const NDArray* indices, NDArray* output, con
} }
}; };
sd::Threads::parallel_tad(func, 0, numOfSubArrs); samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
} }
} }
} }
@ -714,7 +714,7 @@ void eye(sd::LaunchContext * context, NDArray& output) {
arrs.at(i)->setIdentity(); arrs.at(i)->setIdentity();
}; };
sd::Threads::parallel_tad(func, 0, arrs.size()); samediff::Threads::parallel_tad(func, 0, arrs.size());
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -772,7 +772,7 @@ void scatterUpdate(sd::LaunchContext * context, NDArray& input, NDArray& updates
} }
}; };
sd::Threads::parallel_tad(func, 0, indices.size()); samediff::Threads::parallel_tad(func, 0, indices.size());
} }
@ -792,7 +792,7 @@ void scatterSimple(sd::LaunchContext * context, const int opId, NDArray& input,
} }
}; };
sd::Threads::parallel_for(func, 0, len); samediff::Threads::parallel_for(func, 0, len);
} }
break; break;
@ -824,7 +824,7 @@ static void mergeMaxIndex_(const std::vector<NDArray*>& inArrs, NDArray& output)
} }
}; };
sd::Threads::parallel_for(func, 0, x->lengthOf()); samediff::Threads::parallel_for(func, 0, x->lengthOf());
} }
void mergeMaxIndex(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) { void mergeMaxIndex(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
@ -850,7 +850,7 @@ static void mergeMax_(const std::vector<NDArray*>& inArrs, NDArray& output) {
} }
}; };
sd::Threads::parallel_for(func, 0, x->lengthOf()); samediff::Threads::parallel_for(func, 0, x->lengthOf());
} }
void mergeMax(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) { void mergeMax(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
@ -875,7 +875,7 @@ static void mergeAvg_(const std::vector<NDArray*>& inArrs, NDArray& output) {
} }
}; };
sd::Threads::parallel_for(func, 0, x->lengthOf()); samediff::Threads::parallel_for(func, 0, x->lengthOf());
} }
void mergeAvg(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) { void mergeAvg(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
@ -900,7 +900,7 @@ static void mergeAdd_(const std::vector<NDArray*>& inArrs, NDArray& output) {
} }
}; };
sd::Threads::parallel_for(func, 0, x->lengthOf()); samediff::Threads::parallel_for(func, 0, x->lengthOf());
} }
void mergeAdd(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) { void mergeAdd(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (inArrs, output), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (inArrs, output), LIBND4J_TYPES);
@ -934,7 +934,7 @@ static void clipByNorm_(NDArray& input, NDArray& output, const std::vector<int>&
*listOfInSubArrs.at(i) *= normClip / iNormActual; *listOfInSubArrs.at(i) *= normClip / iNormActual;
} }
}; };
sd::Threads::parallel_tad(func, 0, listOfInSubArrs.size()); samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size());
} }
} }
else { else {
@ -963,7 +963,7 @@ static void clipByNorm_(NDArray& input, NDArray& output, const std::vector<int>&
*outputSubArr *= clipNorm / iNormActual; *outputSubArr *= clipNorm / iNormActual;
} }
}; };
sd::Threads::parallel_tad(func, 0, listOfInSubArrs.size()); samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size());
} }
} }
} }
@ -1079,7 +1079,7 @@ static void clipByNormBP_(const NDArray& input, const NDArray& gradO, NDArray& g
gradISubArr->assign(gradOSubArr); gradISubArr->assign(gradOSubArr);
} }
}; };
sd::Threads::parallel_tad(func, 0, gradISubArrs.size()); samediff::Threads::parallel_tad(func, 0, gradISubArrs.size());
} }
} }
@ -1215,7 +1215,7 @@ static void mirrorPad_(const NDArray& input, const NDArray& paddings, NDArray& o
} }
}; };
sd::Threads::parallel_for(func, 0, outLen); samediff::Threads::parallel_for(func, 0, outLen);
} }
} }

View File

@ -99,7 +99,7 @@ namespace helpers {
} }
}; };
sd::Threads::parallel_tad(batchLoop, 0, leftPart.size(), 1); samediff::Threads::parallel_tad(batchLoop, 0, leftPart.size(), 1);
return Status::OK(); return Status::OK();
@ -128,7 +128,7 @@ namespace helpers {
} }
} }
}; };
sd::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1); samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1);
} }
int triangularSolveFunctor(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output) { int triangularSolveFunctor(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output) {

View File

@ -68,7 +68,7 @@ static void zeta_(sd::LaunchContext * context, const NDArray& x, const NDArray&
z.p(i, zetaScalar<T>(x.e<T>(i), q.e<T>(i))); z.p(i, zetaScalar<T>(x.e<T>(i), q.e<T>(i)));
}; };
sd::Threads::parallel_for(func, 0, xLen); samediff::Threads::parallel_for(func, 0, xLen);
} }
void zeta(sd::LaunchContext * context, const NDArray& x, const NDArray& q, NDArray& z) { void zeta(sd::LaunchContext * context, const NDArray& x, const NDArray& q, NDArray& z) {

View File

@ -77,7 +77,7 @@ void FORCEINLINE cross(sd::LaunchContext * context, NDArray *a, NDArray *b, NDAr
} }
}; };
sd::Threads::parallel_tad(func, 0, tads); samediff::Threads::parallel_tad(func, 0, tads);
} }
void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext * context, NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output); void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext * context, NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output);

Some files were not shown because too many files have changed in this diff Show More