Revert "OpenMP Threads execution (#297)" (#299)

This reverts commit dd2043ef48.
master
raver119 2020-03-09 08:22:49 +03:00 committed by GitHub
parent dd2043ef48
commit 57210b936c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
117 changed files with 593 additions and 890 deletions

View File

@ -21,9 +21,9 @@ if (SD_CUDA)
enable_language(CUDA)
set(CMAKE_CUDA_STANDARD 11)
set(DEFAULT_ENGINE "sd::ENGINE_CUDA")
set(DEFAULT_ENGINE "samediff::ENGINE_CUDA")
else()
set(DEFAULT_ENGINE "sd::ENGINE_CPU")
set(DEFAULT_ENGINE "samediff::ENGINE_CPU")
endif()
# MSVC runtime lib can be either "MultiThreaded" or "MultiThreadedDLL", /MT and /MD respectively

View File

@ -56,7 +56,7 @@ namespace sd {
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
};
sd::Threads::parallel_for(func, 0, length);
samediff::Threads::parallel_for(func, 0, length);
#endif
delete[] tmp;
@ -114,7 +114,7 @@ namespace sd {
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
};
sd::Threads::parallel_for(func, 0, length);
samediff::Threads::parallel_for(func, 0, length);
#endif
delete[] tmp;
@ -142,7 +142,7 @@ namespace sd {
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
};
sd::Threads::parallel_for(func, 0, length);
samediff::Threads::parallel_for(func, 0, length);
#endif
delete[] tmp;
}
@ -168,7 +168,7 @@ namespace sd {
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
};
sd::Threads::parallel_for(func, 0, length);
samediff::Threads::parallel_for(func, 0, length);
#endif
delete[] tmp;
}

View File

@ -515,7 +515,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
}
};
sd::Threads::parallel_for(func, 0, lengthOf(), 1);
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
tickWriteHost();
syncToDevice();
@ -582,7 +582,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::stri
}
};
sd::Threads::parallel_for(func, 0, lengthOf(), 1);
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
tickWriteHost();
@ -648,7 +648,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u16s
}
}
};
sd::Threads::parallel_for(func, 0, lengthOf(), 1);
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
tickWriteHost();
syncToDevice();
@ -714,7 +714,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
}
}
};
sd::Threads::parallel_for(func, 0, lengthOf(), 1);
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
tickWriteHost();
syncToDevice();
@ -780,7 +780,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32s
}
}
};
sd::Threads::parallel_for(func, 0, lengthOf(), 1);
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
tickWriteHost();
syncToDevice();
@ -846,7 +846,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
}
}
};
sd::Threads::parallel_for(func, 0, lengthOf(), 1);
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
tickWriteHost();
syncToDevice();
@ -2393,7 +2393,7 @@ NDArray NDArray::asS() const {
}
};
sd::Threads::parallel_for(func, 0, lengthOf(), 1);
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
registerPrimaryUse({ &res }, { this });
@ -3466,7 +3466,7 @@ NDArray NDArray::dup(const char newOrder) const {
}
};
sd::Threads::parallel_for(func, 0, lengthOf(), 1);
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
return NDArray(getShapeAsVector(), strings, dataType(), getContext());
}
@ -3479,7 +3479,7 @@ NDArray NDArray::dup(const char newOrder) const {
}
};
sd::Threads::parallel_for(func, 0, lengthOf(), 1);
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
return NDArray(getShapeAsVector(), strings, dataType(), getContext());
}
@ -3491,7 +3491,7 @@ NDArray NDArray::dup(const char newOrder) const {
}
};
sd::Threads::parallel_for(func, 0, lengthOf(), 1);
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
return NDArray(getShapeAsVector(), strings, dataType(), getContext());
}

View File

@ -115,7 +115,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& t
}
};
sd::Threads::parallel_for(func, 0, zLen);
samediff::Threads::parallel_for(func, 0, zLen);
}
BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, (const float val, int lower, int upper, NDArray& target, const char direction), LIBND4J_TYPES);
@ -159,7 +159,7 @@ static void templatedSwap(void *xBuffer, void *yBuffer, Nd4jLong length) {
}
};
sd::Threads::parallel_for(func, 0, length);
samediff::Threads::parallel_for(func, 0, length);
}
BUILD_SINGLE_TEMPLATE(template void templatedSwap, (void *xBuffer, void *yBuffer, Nd4jLong length), LIBND4J_TYPES);
@ -272,7 +272,7 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
}
};
sd::Threads::parallel_for(func, 0, resultLen);
samediff::Threads::parallel_for(func, 0, resultLen);
}
else {
@ -284,7 +284,7 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
}
};
sd::Threads::parallel_for(func, 0, resultLen);
samediff::Threads::parallel_for(func, 0, resultLen);
}
result.tickWriteHost();
return result;
@ -397,7 +397,7 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector<int
}
};
sd::Threads::parallel_for(func, 0, zLen);
samediff::Threads::parallel_for(func, 0, zLen);
}
//////////////////////////////////////////////////////////////////////////

View File

@ -26,7 +26,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
z[e] = func(f[e], s[e], t[e]);
};
sd::Threads::parallel_for(loop, 0, _length);
samediff::Threads::parallel_for(loop, 0, _length);
} else {
if (f == z) {
@ -40,7 +40,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
}
};
sd::Threads::parallel_for(loop, 0, _length);
samediff::Threads::parallel_for(loop, 0, _length);
} else {
auto loop = PRAGMA_THREADS_FOR {
@ -54,7 +54,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
}
};
sd::Threads::parallel_for(loop, 0, _length);
samediff::Threads::parallel_for(loop, 0, _length);
}
}
}
@ -97,7 +97,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
z[e] = func(f[e], s[e]);
};
sd::Threads::parallel_for(loop, 0, _length);
samediff::Threads::parallel_for(loop, 0, _length);
} else {
if (f == z) {
@ -110,7 +110,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
}
};
sd::Threads::parallel_for(loop, 0, _length);
samediff::Threads::parallel_for(loop, 0, _length);
} else {
auto loop = PRAGMA_THREADS_FOR {
@ -123,7 +123,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
}
};
sd::Threads::parallel_for(loop, 0, _length);
samediff::Threads::parallel_for(loop, 0, _length);
}
}
}
@ -160,7 +160,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
z[e] = func(f[e]);
};
sd::Threads::parallel_for(loop, 0, _length);
samediff::Threads::parallel_for(loop, 0, _length);
} else {
if (f == z) {
@ -172,7 +172,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
}
};
sd::Threads::parallel_for(loop, 0, _length);
samediff::Threads::parallel_for(loop, 0, _length);
} else {
auto loop = PRAGMA_THREADS_FOR {
@ -184,7 +184,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
}
};
sd::Threads::parallel_for(loop, 0, _length);
samediff::Threads::parallel_for(loop, 0, _length);
}
}
}
@ -221,7 +221,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
z[e] = func(e, f[e]);
};
sd::Threads::parallel_for(loop, 0, _length);
samediff::Threads::parallel_for(loop, 0, _length);
} else {
if (f == z) {
@ -233,7 +233,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
}
};
sd::Threads::parallel_for(loop, 0, _length);
samediff::Threads::parallel_for(loop, 0, _length);
} else {
auto loop = PRAGMA_THREADS_FOR {
@ -245,7 +245,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
}
};
sd::Threads::parallel_for(loop, 0, _length);
samediff::Threads::parallel_for(loop, 0, _length);
}
}
}
@ -287,7 +287,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
z[e] = func((Nd4jLong) e, f[e], s[e]);
};
sd::Threads::parallel_for(loop, 0, _length);
samediff::Threads::parallel_for(loop, 0, _length);
} else {
if (f == z) {
@ -300,7 +300,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
}
};
sd::Threads::parallel_for(loop, 0, _length);
samediff::Threads::parallel_for(loop, 0, _length);
} else {
auto loop = PRAGMA_THREADS_FOR {
@ -313,7 +313,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
}
};
sd::Threads::parallel_for(loop, 0, _length);
samediff::Threads::parallel_for(loop, 0, _length);
}
}
}

View File

@ -27,7 +27,7 @@
#include <atomic>
#include <condition_variable>
namespace sd {
namespace samediff {
template <typename T>
class BlockingQueue {
private:

View File

@ -29,7 +29,7 @@
#include <mutex>
#include <condition_variable>
namespace sd {
namespace samediff {
/**
* This class is suited for passing functions to execution threads without queues
*/

View File

@ -27,7 +27,7 @@
#include <condition_variable>
#include <system/op_boilerplate.h>
namespace sd {
namespace samediff {
class CallableWithArguments {
FUNC_DO _function_do;
FUNC_1D _function_1d;

View File

@ -21,7 +21,7 @@
#ifndef SD_ENGINE_H
#define SD_ENGINE_H
namespace sd {
namespace samediff {
enum Engine {
ENGINE_CPU = 0,
ENGINE_CUDA = 1,

View File

@ -21,7 +21,7 @@
#ifndef SD_EXECUTIONMODE_H
#define SD_EXECUTIONMODE_H
namespace sd {
namespace samediff {
enum ExecutionMode {
MODE_UNDEFINED = 0,
MODE_TRAINING = 1,

View File

@ -32,7 +32,7 @@
#include <execution/Ticket.h>
#include <queue>
namespace sd {
namespace samediff {
class ND4J_EXPORT ThreadPool {
private:
static ThreadPool* _INSTANCE;

View File

@ -14,9 +14,9 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
//
// @author raver119@gmail.com
//
#ifndef SAMEDIFF_THREADS_H
#define SAMEDIFF_THREADS_H
@ -26,7 +26,7 @@
#include <system/Environment.h>
#include <system/op_enums.h>
namespace sd {
namespace samediff {
class ND4J_EXPORT ThreadsHelper {
public:
static int numberOfThreads(int maxThreads, uint64_t numberOfElements);
@ -95,14 +95,6 @@ namespace sd {
};
class ND4J_EXPORT Threads {
#ifdef _OPENMP
public:
static std::mutex gThreadmutex;
static uint64_t _nFreeThreads;
static bool tryAcquire(int numThreads);
static bool freeThreads(int numThreads);
#endif
public:
/**
* This function executes 1 dimensional loop for a given number of threads

View File

@ -28,7 +28,7 @@
#include <atomic>
#include <mutex>
namespace sd {
namespace samediff {
class ND4J_EXPORT Ticket {
private:
bool _acquired = false;

View File

@ -22,7 +22,7 @@
#include <execution/CallableWithArguments.h>
#include <thread>
namespace sd {
namespace samediff {
template <typename T>
BlockingQueue<T>::BlockingQueue(int queueSize) {
_size = 0;

View File

@ -21,7 +21,7 @@
#include <execution/CallableInterface.h>
#include <helpers/logger.h>
namespace sd {
namespace samediff {
CallableInterface::CallableInterface() {
// initial state is available
_available = true;

View File

@ -20,7 +20,7 @@
#include <execution/CallableWithArguments.h>
namespace sd {
namespace samediff {
CallableWithArguments::CallableWithArguments(FUNC_DO func, uint64_t thread_id, uint64_t numThreads) {
_function_do = func;
_finished = false;

View File

@ -26,7 +26,7 @@
//#include <windows.h>
#endif
namespace sd {
namespace samediff {
// this function executed once per thread, it polls functions from queue, and executes them via wrapper
static void executionLoop_(int thread_id, BlockingQueue<CallableWithArguments*> *queue) {
@ -183,7 +183,7 @@ namespace sd {
}
}
void ThreadPool::release(sd::Ticket *ticket) {
void ThreadPool::release(samediff::Ticket *ticket) {
// returning ticket back to the queue
std::unique_lock<std::mutex> lock(_lock);
_tickets.push(ticket);

View File

@ -25,14 +25,8 @@
#include <math/templatemath.h>
#include <helpers/shape.h>
#ifdef _OPENMP
#include <omp.h>
#endif
namespace sd {
namespace samediff {
int ThreadsHelper::numberOfThreads(int maxThreads, uint64_t numberOfElements) {
// let's see how many threads we actually need first
@ -276,7 +270,7 @@ namespace sd {
auto remY = iters_y % maxThreads;
// in some cases there's nothing to think about, part 2
if ((iters_x >= maxThreads && remX == 0) || (iters_y >= maxThreads && remY == 0))
if ((iters_x >= maxThreads && remX == 0 )|| (iters_y >= maxThreads && remY == 0))
return maxThreads;
// at this point we suppose that there's no loop perfectly matches number of our threads
@ -345,35 +339,11 @@ namespace sd {
return 1;
}
#ifdef _OPENMP
std::mutex Threads::gThreadmutex;
uint64_t Threads::_nFreeThreads = sd::Environment::getInstance()->maxThreads();
bool Threads::tryAcquire(int numThreads){
std::lock_guard<std::mutex> lock( gThreadmutex );
auto nThreads = _nFreeThreads - numThreads;
if(nThreads >= 1){
_nFreeThreads = nThreads;
return true;
}
return false;
}
bool Threads::freeThreads(int numThreads){
std::lock_guard<std::mutex> lock( gThreadmutex );
_nFreeThreads += numThreads;
// check if correct number of threads
return _nFreeThreads > sd::Environment::getInstance()->maxThreads();
}
#endif
int Threads::parallel_tad(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, uint32_t numThreads) {
if (start > stop)
throw std::runtime_error("Threads::parallel_for got start > stop");
auto delta = (stop - start) / increment;
auto delta = (stop - start);
if (numThreads > delta)
numThreads = delta;
@ -387,26 +357,6 @@ namespace sd {
return 1;
}
#ifdef _OPENMP
if (tryAcquire(numThreads)) {
#pragma omp parallel for num_threads(numThreads)
for (auto e = start; e < stop; e += increment) {
function(omp_get_thread_num(), e, e + 1, 1);
}
freeThreads(numThreads);
return numThreads;
}
else {
// if there were no threads available - we'll execute function right within current thread
function(0, start, stop, increment);
// we tell that parallelism request declined
return 1;
}
#else
sd::Environment::getInstance()->maxThreads();
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
if (ticket != nullptr) {
// if we got our threads - we'll run our jobs here
@ -429,15 +379,13 @@ namespace sd {
// we tell that parallelism request succeeded
return numThreads;
}
else {
} else {
// if there were no threads available - we'll execute function right within current thread
function(0, start, stop, increment);
// we tell that parallelism request declined
return 1;
}
#endif
}
int Threads::parallel_for(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, uint32_t numThreads) {
@ -500,30 +448,7 @@ namespace sd {
// but we still mimic multithreaded execution
return numThreads;
}
else {
#ifdef _OPENMP
if (tryAcquire(numThreads)) {
#pragma omp parallel for num_threads(numThreads) collapse(2)
for (auto x = startX; x < stopX; x += incX) {
for (auto y = startY; y < stopY; y += incY) {
function(omp_get_thread_num(), x, x+1, 1, y, y+1, 1);
}
}
freeThreads(numThreads);
return numThreads;
}
else {
// if there were no threads available - we'll execute function right within current thread
function(0, startX, stopX, incX, startY, stopY, incY);
// we tell that parallelism request declined
return 1;
}
#else
} else {
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
if (ticket != nullptr) {
@ -538,15 +463,13 @@ namespace sd {
ticket->waitAndRelease();
return numThreads;
}
else {
} else {
// if there were no threads available - we'll execute function right within current thread
function(0, startX, stopX, incX, startY, stopY, incY);
// we tell that parallelism request declined
return 1;
}
#endif
};
}
@ -561,35 +484,6 @@ namespace sd {
if (startZ > stopZ)
throw std::runtime_error("Threads::parallel_for got startZ > stopZ");
if (numThreads == 1) {
// loop is too small - executing function as is
function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
return 1;
}
#ifdef _OPENMP
if (tryAcquire(numThreads)) {
#pragma omp parallel for num_threads(numThreads) collapse(3)
for (auto x = startX; x < stopX; x += incX) {
for (auto y = startY; y < stopY; y += incY) {
for (auto z = startZ; z < stopZ; z += incZ) {
function(omp_get_thread_num(), x, x+1, 1, y, y+1, 1, z, z+1, 1);
}
}
}
freeThreads(numThreads);
return numThreads;
}
else {
// if there were no threads available - we'll execute function right within current thread
function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
// we tell that parallelism request declined
return 1;
}
#else
auto delta_x = stopX - startX;
auto delta_y = stopY - startY;
auto delta_z = stopZ - startZ;
@ -621,43 +515,17 @@ namespace sd {
// we tell that parallelism request succeeded
return numThreads;
}
else {
} else {
// if there were no threads available - we'll execute function right within current thread
function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
// we tell that parallelism request declined
return 1;
}
#endif
}
int Threads::parallel_do(FUNC_DO function, uint64_t numThreads) {
if (numThreads == 1) {
function(0, numThreads);
return 1;
}
#ifdef _OPENMP
if (tryAcquire(numThreads)) {
#pragma omp parallel for num_threads(numThreads)
for (int e = 0; e < numThreads; e++) {
function(e, numThreads);
}
freeThreads(numThreads);
return numThreads;
}
else {
// if there's no threads available - we'll execute function sequentially one by one
for (uint64_t e = 0; e < numThreads; e++)
function(e, numThreads);
return numThreads;
}
#else
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
if (ticket != nullptr) {
@ -670,15 +538,14 @@ namespace sd {
ticket->waitAndRelease();
return numThreads;
}
else {
} else {
// if there's no threads available - we'll execute function sequentially one by one
for (uint64_t e = 0; e < numThreads; e++)
function(e, numThreads);
return numThreads;
}
#endif
return numThreads;
}
@ -698,30 +565,14 @@ namespace sd {
if (numThreads == 1)
return function(0, start, stop, increment);
// create temporary array
int64_t intermediatery[256];
auto span = delta / numThreads;
#ifdef _OPENMP
if (tryAcquire(numThreads)) {
#pragma omp parallel for num_threads(numThreads)
for (int e = 0; e < numThreads; e++) {
auto start_ = span * e + start;
auto stop_ = span * (e + 1) + start;
intermediatery[e] = function(e, start_, e == numThreads - 1 ? stop : stop_, increment);
}
freeThreads(numThreads);
}
else{
// if there were no thre ads available - we'll execute function right within current thread
return function(0, start, stop, increment);
}
#else
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
if (ticket == nullptr)
return function(0, start, stop, increment);
// create temporary array
int64_t intermediatery[256];
auto span = delta / numThreads;
// execute threads in parallel
for (uint32_t e = 0; e < numThreads; e++) {
auto start_ = span * e + start;
@ -735,8 +586,6 @@ namespace sd {
ticket->waitAndRelease();
#endif
// aggregate results in single thread
for (uint64_t e = 1; e < numThreads; e++)
intermediatery[0] = aggregator(intermediatery[0], intermediatery[e]);
@ -760,33 +609,14 @@ namespace sd {
if (numThreads == 1)
return function(0, start, stop, increment);
// create temporary array
double intermediatery[256];
auto span = delta / numThreads;
#ifdef _OPENMP
if (tryAcquire(numThreads)) {
#pragma omp parallel for num_threads(numThreads)
for (int e = 0; e < numThreads; e++) {
auto start_ = span * e + start;
auto stop_ = span * (e + 1) + start;
intermediatery[e] = function(e, start_, e == numThreads - 1 ? stop : stop_, increment);
}
freeThreads(numThreads);
}
else{
// if there were no thre ads available - we'll execute function right within current thread
return function(0, start, stop, increment);
}
#else
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
if (ticket == nullptr)
return function(0, start, stop, increment);
// create temporary array
double intermediatery[256];
auto span = delta / numThreads;
// execute threads in parallel
for (uint32_t e = 0; e < numThreads; e++) {
auto start_ = span * e + start;
@ -800,8 +630,6 @@ namespace sd {
ticket->waitAndRelease();
#endif
// aggregate results in single thread
for (uint64_t e = 1; e < numThreads; e++)
intermediatery[0] = aggregator(intermediatery[0], intermediatery[e]);
@ -811,7 +639,7 @@ namespace sd {
}
int Threads::parallel_aligned_increment(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, size_t type_size, uint32_t req_numThreads) {
int Threads::parallel_aligned_increment(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, size_t type_size , uint32_t req_numThreads) {
if (start > stop)
throw std::runtime_error("Threads::parallel_for got start > stop");
auto num_elements = (stop - start);
@ -819,7 +647,6 @@ namespace sd {
//so we will parition considering delta but not total elements
auto delta = (stop - start) / increment;
// in some cases we just fire func as is
if (delta == 0 || req_numThreads == 1) {
function(0, start, stop, increment);
@ -827,24 +654,7 @@ namespace sd {
}
int numThreads = 0;
struct th_span {
Nd4jLong start;
Nd4jLong end;
};
#ifdef _OPENMP
constexpr int max_thread_count = 8;
#else
constexpr int max_thread_count = 1024;
#endif
th_span thread_spans[max_thread_count];
req_numThreads = req_numThreads > max_thread_count ? max_thread_count : req_numThreads;
#ifdef _OPENMP
int adjusted_numThreads = max_thread_count;
#else
int adjusted_numThreads = sd::ThreadsHelper::numberOfThreads(req_numThreads, (num_elements * sizeof(double)) / (200 * type_size));
#endif
int adjusted_numThreads = samediff::ThreadsHelper::numberOfThreads(req_numThreads, (num_elements * sizeof(double)) / (200 * type_size));
if (adjusted_numThreads > delta)
adjusted_numThreads = delta;
@ -853,15 +663,13 @@ namespace sd {
function(0, start, stop, increment);
return 1;
}
//take span as ceil
auto spand = std::ceil((double)delta / (double)adjusted_numThreads);
numThreads = static_cast<int>(std::ceil((double)delta / spand));
auto span = static_cast<Nd4jLong>(spand);
auto ticket = samediff::ThreadPool::getInstance()->tryAcquire(numThreads);
if (ticket != nullptr) {
//tail_add is additional value of the last part
//it could be negative or positive
//we will spread that value across
@ -886,43 +694,18 @@ namespace sd {
for (int i = 0; i < last; i++) {
end = begin + span1 * increment;
// putting the task into the queue for a given thread
thread_spans[i].start = begin;
thread_spans[i].end = end;
ticket->enqueue(i, numThreads, function, begin, end, increment);
begin = end;
}
for (int i = last; i < numThreads - 1; i++) {
end = begin + span2 * increment;
// putting the task into the queue for a given thread
thread_spans[i].start = begin;
thread_spans[i].end = end;
ticket->enqueue(i, numThreads, function, begin, end, increment);
begin = end;
}
//for last one enqueue last offset as stop
//we need it in case our ((stop-start) % increment ) > 0
thread_spans[numThreads - 1].start = begin;
thread_spans[numThreads - 1].end = stop;
#ifdef _OPENMP
if (tryAcquire(numThreads)) {
#pragma omp parallel for num_threads(numThreads)
for (size_t j = 0; j < numThreads; j++) {
function(j, thread_spans[j].start, thread_spans[j].end, increment);
}
freeThreads(numThreads);
return numThreads;
}
else {
function(0, start, stop, increment);
// we tell that parallelism request declined
return 1;
}
#else
auto ticket = sd::ThreadPool::getInstance()->tryAcquire(numThreads);
if (ticket != nullptr) {
for (size_t j = 0; j < numThreads; j++) {
ticket->enqueue(j, numThreads, function, thread_spans[j].start, thread_spans[j].end, increment);
}
ticket->enqueue(numThreads - 1, numThreads, function, begin, stop, increment);
// block and wait till all threads finished the job
ticket->waitAndRelease();
// we tell that parallelism request succeeded
@ -934,8 +717,7 @@ namespace sd {
// we tell that parallelism request declined
return 1;
}
#endif
}
}

View File

@ -23,7 +23,7 @@
#include <helpers/logger.h>
#include <array>
namespace sd {
namespace samediff {
Ticket::Ticket(const std::vector<BlockingQueue<CallableWithArguments*>*> &queues) {
_acquired = true;
_queues = queues;
@ -38,7 +38,7 @@ namespace sd {
return _acquired;
}
void Ticket::enqueue(int thread_id, sd::CallableWithArguments *callable) {
void Ticket::enqueue(int thread_id, samediff::CallableWithArguments *callable) {
_queues[thread_id]->put(callable);
_callables.emplace_back(callable);
}
@ -88,7 +88,7 @@ namespace sd {
}
void Ticket::attach(uint32_t thread_id, sd::CallableInterface *interface) {
void Ticket::attach(uint32_t thread_id, samediff::CallableInterface *interface) {
_interfaces[thread_id] = interface;
}
}

View File

@ -112,7 +112,7 @@ namespace sd {
sd::random::RandomBuffer* getRNG();
void setRNG(sd::random::RandomBuffer* rng);
void setTargetEngine(sd::Engine engine);
void setTargetEngine(samediff::Engine engine);
VariableSpace *getVariableSpace();
@ -228,8 +228,8 @@ namespace sd {
void setShapeFunctionOverride(bool reallyOverride);
bool shapeFunctionOverride();
sd::ExecutionMode executionMode();
void setExecutionMode(sd::ExecutionMode executionMode);
samediff::ExecutionMode executionMode();
void setExecutionMode(samediff::ExecutionMode executionMode);
bool isTraining();
bool isInference();

View File

@ -64,9 +64,9 @@ namespace sd {
bool _useMKLDNN = sd::Environment::getInstance()->isUseMKLDNN();
// target engine for execution
sd::Engine _engine = DEFAULT_ENGINE;
samediff::Engine _engine = DEFAULT_ENGINE;
sd::ExecutionMode _execMode = sd::ExecutionMode::MODE_UNDEFINED;
samediff::ExecutionMode _execMode = samediff::ExecutionMode::MODE_UNDEFINED;
public:
explicit ContextPrototype(sd::ops::OpDescriptor* opDescriptor = nullptr, int nodeId = 1, bool inPlace = false);
~ContextPrototype() = default;
@ -99,7 +99,7 @@ namespace sd {
std::vector<sd::DataType>* getDArguments();
std::vector<int>* getAxis();
sd::Engine engine();
samediff::Engine engine();
size_t numT();
size_t numI();

View File

@ -107,7 +107,7 @@ namespace sd {
delete _context;
}
void Context::setTargetEngine(sd::Engine engine) {
void Context::setTargetEngine(samediff::Engine engine) {
_engine = engine;
}
@ -548,20 +548,20 @@ namespace sd {
return _shapeFunctionOverride;
}
sd::ExecutionMode Context::executionMode() {
samediff::ExecutionMode Context::executionMode() {
return _execMode;
}
void Context::setExecutionMode(sd::ExecutionMode executionMode) {
void Context::setExecutionMode(samediff::ExecutionMode executionMode) {
_execMode = executionMode;
}
bool Context::isTraining() {
return _execMode == sd::ExecutionMode::MODE_TRAINING;
return _execMode == samediff::ExecutionMode::MODE_TRAINING;
}
bool Context::isInference() {
return _execMode == sd::ExecutionMode::MODE_INFERENCE;
return _execMode == samediff::ExecutionMode::MODE_INFERENCE;
}
void Context::setDArguments(sd::DataType *arguments, int numberOfArguments) {

View File

@ -59,7 +59,7 @@ namespace sd {
}
}
sd::Engine ContextPrototype::engine() {
samediff::Engine ContextPrototype::engine() {
return _engine;
}

View File

@ -511,7 +511,7 @@ namespace sd {
//*********************************************//
case LoopKind::EWS1: {
auto span = sd::Span::build(threadId, numThreads, 0, len, 1);
auto span = samediff::Span::build(threadId, numThreads, 0, len, 1);
int64_t start = span.startX(), stop = span.stopX();
for (auto i = start; i < stop; i++)
@ -524,7 +524,7 @@ namespace sd {
const uint xEws = shape::elementWiseStride(xShapeInfo);
const uint zEws = shape::elementWiseStride(zShapeInfo);
auto span = sd::Span::build(threadId, numThreads, 0, len, 1);
auto span = samediff::Span::build(threadId, numThreads, 0, len, 1);
int64_t start = span.startX(), stop = span.stopX();
for (auto i = start; i < stop; i++)
@ -538,7 +538,7 @@ namespace sd {
uint castXShapeInfo[MAX_RANK];
const bool canCastX = sd::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, castXShapeInfo);
auto span = sd::Span::build(threadId, numThreads, 0, len, 1);
auto span = samediff::Span::build(threadId, numThreads, 0, len, 1);
int64_t start = span.startX(), stop = span.stopX();
if (zEws > 1) {
@ -558,7 +558,7 @@ namespace sd {
//*********************************************//
case LoopKind::RANK1: {
auto span = sd::Span::build(threadId, numThreads, 0, len, 1);
auto span = samediff::Span::build(threadId, numThreads, 0, len, 1);
for (auto i0 = span.startX(); i0 < span.stopX(); i0++)
z[i0 * zStride[0]] = OpType::op(x[i0 * xStride[0]], extraParams);
@ -570,8 +570,8 @@ namespace sd {
auto uXShape0 = static_cast<uint>(xShape[0]);
auto uXShape1 = static_cast<uint>(xShape[1]);
auto loop = sd::ThreadsHelper::pickLoop2d(numThreads, uXShape0, uXShape1);
auto span = sd::Span2::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1);
auto loop = samediff::ThreadsHelper::pickLoop2d(numThreads, uXShape0, uXShape1);
auto span = samediff::Span2::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1);
for (auto i0 = span.startX(); i0 < span.stopX(); i0++) {
auto z0 = i0 * zStride[0];
@ -589,8 +589,8 @@ namespace sd {
auto uXShape1 = xShape[1];
auto uXShape2 = xShape[2];
auto loop = sd::ThreadsHelper::pickLoop2d(numThreads, uXShape0, uXShape1);
auto span = sd::Span2::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1);
auto loop = samediff::ThreadsHelper::pickLoop2d(numThreads, uXShape0, uXShape1);
auto span = samediff::Span2::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1);
for (auto i0 = span.startX(); i0 < span.stopX(); i0++)
@ -611,8 +611,8 @@ namespace sd {
auto uXShape2 = xShape[2];
auto uXShape3 = xShape[3];
auto loop = sd::ThreadsHelper::pickLoop3d(numThreads, uXShape0, uXShape1, uXShape2);
auto span = sd::Span3::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1, 0, uXShape2, 1);
auto loop = samediff::ThreadsHelper::pickLoop3d(numThreads, uXShape0, uXShape1, uXShape2);
auto span = samediff::Span3::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1, 0, uXShape2, 1);
for (auto i0 = span.startX(); i0 < span.stopX(); i0++)
for (auto i1 = span.startY(); i1 < span.stopY(); i1++)
@ -634,8 +634,8 @@ namespace sd {
auto uXShape3 = xShape[3];
auto uXShape4 = xShape[4];
auto loop = sd::ThreadsHelper::pickLoop3d(numThreads, uXShape0, uXShape1, uXShape2);
auto span = sd::Span3::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1, 0, uXShape2, 1);
auto loop = samediff::ThreadsHelper::pickLoop3d(numThreads, uXShape0, uXShape1, uXShape2);
auto span = samediff::Span3::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1, 0, uXShape2, 1);
for (auto i0 = span.startX(); i0 < span.stopX(); i0++)
@ -666,7 +666,7 @@ namespace sd {
bool canCastX = DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
bool canCastZ = DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
auto span = sd::Span::build(threadId, numThreads, 0, len, 1);
auto span = samediff::Span::build(threadId, numThreads, 0, len, 1);
for (auto i = span.startX(); i < span.stopX(); i++) {
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);

View File

@ -93,7 +93,7 @@ static void usualGemm(const NDArray* vA, const NDArray* vB, NDArray* vC,
}
};
sd::Threads::parallel_tad(func, 0, cLen);
samediff::Threads::parallel_tad(func, 0, cLen);
}
@ -146,7 +146,7 @@ static void usualGemv(const NDArray* vA, const NDArray* vX, NDArray* vY, const
}
};
sd::Threads::parallel_tad(func, 0, M);
samediff::Threads::parallel_tad(func, 0, M);
}
//////////////////////////////////////////////////////////////////////////////
@ -477,7 +477,7 @@ static void batchedGemm(const NDArray* vA, const NDArray* vB, NDArray* vC,
}
};
sd::Threads::parallel_tad(func, 0, cLen);
samediff::Threads::parallel_tad(func, 0, cLen);
}
//////////////////////////////////////////////////////////////////////////
@ -669,7 +669,7 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c
}
};
sd::Threads::parallel_tad(func, 0, M, 1, 0, N, 1);
samediff::Threads::parallel_tad(func, 0, M, 1, 0, N, 1);
}
//////////////////////////////////////////////////////////////////////////////
@ -703,7 +703,7 @@ static void usualGemv(const char aOrder, const int M, const int N, const double
}
};
sd::Threads::parallel_tad(func, 0, M);
samediff::Threads::parallel_tad(func, 0, M);
}
*/

View File

@ -62,7 +62,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
}
};
sd::Threads::parallel_tad(func, 0, zLen);
samediff::Threads::parallel_tad(func, 0, zLen);
}
break;
@ -83,7 +83,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
}
};
sd::Threads::parallel_tad(func, 0, zLen);
samediff::Threads::parallel_tad(func, 0, zLen);
}
break;
@ -104,7 +104,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
}
};
sd::Threads::parallel_tad(func, 0, zLen);
samediff::Threads::parallel_tad(func, 0, zLen);
}
break;
@ -131,7 +131,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
}
};
sd::Threads::parallel_tad(func, 0, zLen);
samediff::Threads::parallel_tad(func, 0, zLen);
}
break;
@ -160,7 +160,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
}
};
sd::Threads::parallel_tad(func, 0, zLen);
samediff::Threads::parallel_tad(func, 0, zLen);
}
break;
@ -191,7 +191,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
}
};
sd::Threads::parallel_tad(func, 0, zLen);
samediff::Threads::parallel_tad(func, 0, zLen);
}
break;
@ -224,7 +224,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
}
};
sd::Threads::parallel_tad(func, 0, zLen);
samediff::Threads::parallel_tad(func, 0, zLen);
}
break;
@ -248,7 +248,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
}
};
sd::Threads::parallel_tad(func, 0, zLen);
samediff::Threads::parallel_tad(func, 0, zLen);
}
break;
@ -272,7 +272,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
}
};
sd::Threads::parallel_tad(func, 0, zLen);
samediff::Threads::parallel_tad(func, 0, zLen);
}
break;
@ -299,7 +299,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
}
};
sd::Threads::parallel_tad(func, 0, zLen);
samediff::Threads::parallel_tad(func, 0, zLen);
}
}
}

View File

@ -99,7 +99,7 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided) reduction(+:_nanCount,_infCount,_m
return _stdDevValue;
};
_stdDevValue = sd::Threads::parallel_double(func, LAMBDA_AD { return _old + _new; }, 0, input->lengthOf());
_stdDevValue = samediff::Threads::parallel_double(func, LAMBDA_AD { return _old + _new; }, 0, input->lengthOf());
info->_stdDevValue = math::nd4j_sqrt<double, double>(_stdDevValue / input->lengthOf());

View File

@ -199,7 +199,7 @@ void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc,
}
}
sd::Threads::parallel_tad(func, 0, numTads);
samediff::Threads::parallel_tad(func, 0, numTads);
#endif
}
@ -237,7 +237,7 @@ void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc,
auto yLen = shape::length(hYShapeInfo);
auto numTads = yLen / xLen;
sd::Threads::parallel_tad(func, 0, numTads);
samediff::Threads::parallel_tad(func, 0, numTads);
#endif
}
@ -273,7 +273,7 @@ void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext *lc,
auto yLen = shape::length(hYShapeInfo);
auto numTads = xLen / yLen;
sd::Threads::parallel_tad(func, 0, numTads);
samediff::Threads::parallel_tad(func, 0, numTads);
}
void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc,
@ -308,7 +308,7 @@ void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc,
auto yLen = shape::length(hYShapeInfo);
auto numTads = yLen / xLen;
sd::Threads::parallel_tad(func, 0, numTads);
samediff::Threads::parallel_tad(func, 0, numTads);
}
@ -348,7 +348,7 @@ void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc,
auto yLen = shape::length(hYShapeInfo);
auto numTads = xLen / yLen;
sd::Threads::parallel_tad(func, 0, numTads);
samediff::Threads::parallel_tad(func, 0, numTads);
}
void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc,
@ -384,7 +384,7 @@ void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc,
auto yLen = shape::length(hYShapeInfo);
auto numTads = yLen / xLen;
sd::Threads::parallel_tad(func, 0, numTads);
samediff::Threads::parallel_tad(func, 0, numTads);
}
////////////////////////////////////////////////////////////////////////
@ -427,7 +427,7 @@ void NativeOpExecutioner::execPairwiseTransform(sd::LaunchContext *lc,
};
auto zLen = shape::length(hZShapeInfo);
sd::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
#endif
}
@ -462,7 +462,7 @@ void NativeOpExecutioner::execPairwiseBoolTransform(sd::LaunchContext *lc,
};
auto zLen = shape::length(hZShapeInfo);
sd::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
}
@ -495,7 +495,7 @@ void NativeOpExecutioner::execPairwiseIntTransform(sd::LaunchContext *lc,
};
auto zLen = shape::length(hZShapeInfo);
sd::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
}
@ -534,7 +534,7 @@ void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
sd::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
}
////////////////////////////////////////////////////////////////////////
@ -562,7 +562,7 @@ void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc,
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
sd::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
}
////////////////////////////////////////////////////////////////////////
@ -590,7 +590,7 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
sd::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
}
////////////////////////////////////////////////////////////////////////
@ -618,7 +618,7 @@ void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
sd::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
}
////////////////////////////////////////////////////////////////////////
@ -791,7 +791,7 @@ void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
};
sd::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
}
@ -820,7 +820,7 @@ void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc,
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execAll(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
};
sd::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
}
////////////////////////////////////////////////////////////////////////
@ -861,7 +861,7 @@ void NativeOpExecutioner::execReduce3TAD(sd::LaunchContext *lc,
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
};
sd::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
}
@ -905,7 +905,7 @@ void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
};
auto zLen = shape::length(hZShapeInfo);
sd::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
#endif
}
@ -942,7 +942,7 @@ void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
};
auto yLen = shape::length(hScalarShapeInfo);
sd::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads()));
samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads()));
#endif
}
@ -976,7 +976,7 @@ void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
};
auto zLen = shape::length(hZShapeInfo);
sd::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
}
@ -1012,7 +1012,7 @@ void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
};
auto yLen = shape::length(hScalarShapeInfo);
sd::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads()));
samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads()));
}
////////////////////////////////////////////////////////////////////////
@ -1044,7 +1044,7 @@ void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
};
auto zLen = shape::length(hZShapeInfo);
sd::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
}
@ -1080,7 +1080,7 @@ void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
};
auto yLen = shape::length(hScalarShapeInfo);
sd::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads()));
samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads()));
}
////////////////////////////////////////////////////////////////////////
@ -1193,7 +1193,7 @@ void NativeOpExecutioner::execTransformFloat(sd::LaunchContext *lc,
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, FLOAT_TYPES);
};
sd::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
}
////////////////////////////////////////////////////////////////////////
@ -1215,7 +1215,7 @@ void NativeOpExecutioner::execTransformBool(sd::LaunchContext *lc,
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, BOOL_TYPES);
};
sd::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
}
////////////////////////////////////////////////////////////////////////
@ -1243,7 +1243,7 @@ void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc,
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES);
};
sd::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
}
}
@ -1266,7 +1266,7 @@ void NativeOpExecutioner::execTransformSame(sd::LaunchContext *lc,
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES);
};
sd::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
}
////////////////////////////////////////////////////////////////////////
@ -1288,7 +1288,7 @@ void NativeOpExecutioner::execTransformStrict(sd::LaunchContext *lc,
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), FLOAT_TYPES);
};
sd::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
}
////////////////////////////////////////////////////////////////////////

View File

@ -1318,7 +1318,7 @@ void pullRowsGeneric(void *vx,
}
};
sd::Threads::parallel_tad(func, 0, n, 1, _threads);
samediff::Threads::parallel_tad(func, 0, n, 1, _threads);
}
void pullRows(Nd4jPointer *extraPointers,
@ -1377,7 +1377,7 @@ void tearGeneric(void *vx,
}
};
sd::Threads::parallel_tad(func,0, numTads);
samediff::Threads::parallel_tad(func,0, numTads);
}
void tear(Nd4jPointer *extraPointers,
@ -1530,7 +1530,7 @@ void shuffleGeneric(void **hX, Nd4jLong **hXShapeInfo, void **dz, Nd4jLong **hZS
}
};
sd::Threads::parallel_tad(func, 0, N);
samediff::Threads::parallel_tad(func, 0, N);
}
void shuffle(Nd4jPointer *extras,
@ -1944,7 +1944,7 @@ FORCEINLINE int estimateThresholdGeneric(Nd4jPointer *extraPointers, Nd4jPointer
return cnt;
};
return sd::Threads::parallel_long(func, LAMBDA_AL { return _old + _new; }, 0, N);
return samediff::Threads::parallel_long(func, LAMBDA_AL { return _old + _new; }, 0, N);
}
@ -2653,7 +2653,7 @@ static void _scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSub
}
};
sd::Threads::parallel_do(func);
samediff::Threads::parallel_do(func);
}
////////////////////////////////////////////////////////////////////////
@ -2812,7 +2812,7 @@ void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) {
if (execMode < 0 || execMode > 2)
execMode = 0;
ptr->setExecutionMode((sd::ExecutionMode) execMode);
ptr->setExecutionMode((samediff::ExecutionMode) execMode);
}
void ctxPurge(OpaqueContext* ptr) {

View File

@ -3799,7 +3799,7 @@ void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) {
if (execMode < 0 || execMode > 2)
execMode = 0;
ptr->setExecutionMode((sd::ExecutionMode) execMode);
ptr->setExecutionMode((samediff::ExecutionMode) execMode);
}
OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) {

View File

@ -60,7 +60,7 @@ namespace sd {
}
}
};
sd::Threads::parallel_tad(func, 0, xArr.lengthOf());
samediff::Threads::parallel_tad(func, 0, xArr.lengthOf());
return;
}
@ -95,7 +95,7 @@ namespace sd {
}
}
};
sd::Threads::parallel_tad(func, 0, nLen, 1);
samediff::Threads::parallel_tad(func, 0, nLen, 1);
return;
}
@ -137,7 +137,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, zLen);
samediff::Threads::parallel_for(func, 0, zLen);
}
template <typename X, typename Y, typename Z>
@ -200,7 +200,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, zLen);
samediff::Threads::parallel_for(func, 0, zLen);
}
template <typename X, typename Y>
@ -263,7 +263,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, zLen);
samediff::Threads::parallel_for(func, 0, zLen);
}
template <typename X>

View File

@ -79,7 +79,7 @@ Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vex
}
};
maxThreads = sd::Threads::parallel_for(func, 0, len, 1, maxThreads);
maxThreads = samediff::Threads::parallel_for(func, 0, len, 1, maxThreads);
for (int e = 0; e < maxThreads; e++)
startingIndex = OpType::update(startingIndex, intermediatery[e], extraParams);
@ -95,7 +95,7 @@ Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vex
}
};
maxThreads = sd::Threads::parallel_for(func, 0, len, 1, maxThreads);
maxThreads = samediff::Threads::parallel_for(func, 0, len, 1, maxThreads);
for (int e = 0; e < maxThreads; e++)
startingIndex = OpType::update(startingIndex, intermediatery[e], extraParams);

View File

@ -67,7 +67,7 @@ namespace functions {
z[i] = OpClass::op(x[i], y[i], i, length, rng, extraArguments);
}
};
sd::Threads::parallel_for(func, 0, length, 1);
samediff::Threads::parallel_for(func, 0, length, 1);
}
else{
uint xShapeInfoCast[MAX_RANK];
@ -81,7 +81,7 @@ namespace functions {
}
};
sd::Threads::parallel_for(func, 0, length, 1);
samediff::Threads::parallel_for(func, 0, length, 1);
}
}
else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
@ -100,7 +100,7 @@ namespace functions {
}
};
sd::Threads::parallel_for(func, 0, length, 1);
samediff::Threads::parallel_for(func, 0, length, 1);
}
else if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
@ -118,7 +118,7 @@ namespace functions {
}
};
sd::Threads::parallel_for(func, 0, length, 1);
samediff::Threads::parallel_for(func, 0, length, 1);
}
else if (shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) {
@ -136,7 +136,7 @@ namespace functions {
}
};
sd::Threads::parallel_for(func, 0, length, 1);
samediff::Threads::parallel_for(func, 0, length, 1);
}
else {
@ -157,7 +157,7 @@ namespace functions {
}
};
sd::Threads::parallel_for(func, 0, length, 1);
samediff::Threads::parallel_for(func, 0, length, 1);
}
};
@ -192,7 +192,7 @@ namespace functions {
z[i] = OpClass::op(x[i], i, length, rng, extraArguments);
}
};
sd::Threads::parallel_for(func, 0, length, 1);
samediff::Threads::parallel_for(func, 0, length, 1);
}
else{
auto func = PRAGMA_THREADS_FOR {
@ -203,7 +203,7 @@ namespace functions {
}
};
sd::Threads::parallel_for(func, 0, length, 1);
samediff::Threads::parallel_for(func, 0, length, 1);
}
}
else {
@ -220,7 +220,7 @@ namespace functions {
}
};
sd::Threads::parallel_for(func, 0, length, 1);
samediff::Threads::parallel_for(func, 0, length, 1);
}
}
@ -245,7 +245,7 @@ namespace functions {
}
};
sd::Threads::parallel_for(func, 0, length, 1);
samediff::Threads::parallel_for(func, 0, length, 1);
}
else{
sd::OmpLaunchHelper info(length);
@ -261,7 +261,7 @@ namespace functions {
}
};
sd::Threads::parallel_for(func, 0, length, 1);
samediff::Threads::parallel_for(func, 0, length, 1);
}
}

View File

@ -208,26 +208,13 @@ namespace functions {
Z _CUDA_H ReduceBoolFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads());
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
Z intermediate[64];
PRAGMA_OMP_SIMD
for (auto e = 0; e < maxThreads; e++)
intermediate[e] = OpType::startingValue(x);
#ifdef _OPENMP
if (xEws == 1) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
for (Nd4jLong i = 0; i < length; i++)
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i], extraParams), extraParams);
} else {
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
for (Nd4jLong i = 0; i < length; i++)
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i * xEws], extraParams), extraParams);
}
#else
auto func = PRAGMA_THREADS_FOR {
if (xEws == 1) {
for (auto i = start; i < stop; i++)
@ -238,9 +225,7 @@ namespace functions {
}
};
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads);
#endif
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
// merge results
for (int e = 1; e < maxThreads; e++)

View File

@ -72,7 +72,7 @@ namespace functions {
auto startingValue = OpType::startingValue(x);
uint xShapeInfoCast[MAX_RANK];
const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads());
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
Z intermediate[64];
PRAGMA_OMP_SIMD
@ -84,7 +84,7 @@ namespace functions {
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
};
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads);
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
// merge results
for (int e = 1; e < maxThreads; e++)
@ -242,27 +242,13 @@ namespace functions {
auto x = reinterpret_cast<X *>(vx);
auto extraParams = reinterpret_cast<Z *>(vextraParams);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads());
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
Z intermediate[64];
PRAGMA_OMP_SIMD
for (auto e = 0; e < maxThreads; e++)
intermediate[e] = OpType::startingValue(x);
#ifdef _OPENMP
if (xEws == 1) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
for (Nd4jLong i = 0; i < length; i++)
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i], extraParams), extraParams);
} else {
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
for (Nd4jLong i = 0; i < length; i++)
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i * xEws], extraParams), extraParams);
}
#else
auto func = PRAGMA_THREADS_FOR {
if (xEws == 1) {
for (auto i = start; i < stop; i++)
@ -273,9 +259,7 @@ namespace functions {
}
};
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads);
#endif
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
// merge results
for (int e = 1; e < maxThreads; e++)

View File

@ -67,7 +67,7 @@ namespace functions {
auto startingValue = OpType::startingValue(x);
uint xShapeInfoCast[MAX_RANK];
const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads());
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
Z intermediate[64];
PRAGMA_OMP_SIMD
@ -79,7 +79,7 @@ namespace functions {
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
};
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads);
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
// merge results
for (int e = 1; e < maxThreads; e++)
@ -231,26 +231,13 @@ namespace functions {
auto x = reinterpret_cast<X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads());
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
Z intermediate[64];
PRAGMA_OMP_SIMD
for (auto e = 0; e < maxThreads; e++)
intermediate[e] = OpType::startingValue(x);
#ifdef _OPENMP
if (xEws == 1) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
for (Nd4jLong i = 0; i < length; i++)
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i], extraParams), extraParams);
} else {
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
for (Nd4jLong i = 0; i < length; i++)
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i * xEws], extraParams), extraParams);
}
#else
auto func = PRAGMA_THREADS_FOR {
if (xEws == 1) {
for (auto i = start; i < stop; i++)
@ -261,9 +248,7 @@ namespace functions {
}
};
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads);
#endif
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
// merge results
for (int e = 1; e < maxThreads; e++)

View File

@ -69,7 +69,7 @@ namespace functions {
auto startingValue = OpType::startingValue(x);
uint xShapeInfoCast[MAX_RANK];
const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads());
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
X intermediate[64];
PRAGMA_OMP_SIMD
@ -81,7 +81,7 @@ namespace functions {
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
};
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads);
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
// merge results
for (int e = 1; e < maxThreads; e++)
@ -240,26 +240,13 @@ namespace functions {
auto x = reinterpret_cast<X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams);
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads());
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
X intermediate[64];
PRAGMA_OMP_SIMD
for (auto e = 0; e < maxThreads; e++)
intermediate[e] = OpType::startingValue(x);
#ifdef _OPENMP
if (xEws == 1) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
for (Nd4jLong i = 0; i < length; i++)
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i], extraParams), extraParams);
} else {
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
for (Nd4jLong i = 0; i < length; i++)
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i * xEws], extraParams), extraParams);
}
#else
auto func = PRAGMA_THREADS_FOR {
if (xEws == 1) {
for (auto i = start; i < stop; i++)
@ -270,9 +257,7 @@ namespace functions {
}
};
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads);
#endif
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
// merge results
for (int e = 1; e < maxThreads; e++)

View File

@ -93,7 +93,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
}
};
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads);
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
} else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
@ -104,7 +104,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
}
};
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads);
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
} else {
uint yShapeInfoCast[MAX_RANK];
const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
@ -117,7 +117,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
}
};
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads);
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
}
// merge step

View File

@ -187,7 +187,7 @@ namespace functions {
}
};
sd::Threads::parallel_tad(func, 0, resultLength, 1);
samediff::Threads::parallel_tad(func, 0, resultLength, 1);
}

View File

@ -86,7 +86,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, N);
samediff::Threads::parallel_for(func, 0, N);
}
template <typename T>
@ -184,7 +184,7 @@ PRAGMA_OMP_ATOMIC_ARGS(write)
}
};
sd::Threads::parallel_for(func, 4, flimit);
samediff::Threads::parallel_for(func, 4, flimit);
}
/**
@ -206,7 +206,7 @@ PRAGMA_OMP_ATOMIC_ARGS(write)
z[i] = static_cast<T>(static_cast<float>(x[i]));
}
};
sd::Threads::parallel_for(func, 0, N);
samediff::Threads::parallel_for(func, 0, N);
};
template void TypeCast::convertFromThreshold<float>(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz);

View File

@ -112,7 +112,7 @@ namespace sd {
*/
int prepareOutputs(Context& block);
virtual sd::EmptyHandling emptyHandling();
virtual samediff::EmptyHandling emptyHandling();
public:
// for special cases, like BooleanOps
DeclarableOp();

View File

@ -21,7 +21,7 @@
#ifndef SAMEDIFF_EMPTYHANDLING_H
#define SAMEDIFF_EMPTYHANDLING_H
namespace sd {
namespace samediff {
enum EmptyHandling {
EMPTY_SKIP = 1,
EMPTY_EXCEPTION = 2,

View File

@ -38,15 +38,15 @@
namespace std {
template <>
class hash<std::pair<Nd4jLong, sd::Engine>> {
class hash<std::pair<Nd4jLong, samediff::Engine>> {
public:
size_t operator()(const std::pair<Nd4jLong, sd::Engine>& k) const;
size_t operator()(const std::pair<Nd4jLong, samediff::Engine>& k) const;
};
template <>
class hash<std::pair<std::string, sd::Engine>> {
class hash<std::pair<std::string, samediff::Engine>> {
public:
size_t operator()(const std::pair<std::string, sd::Engine>& k) const;
size_t operator()(const std::pair<std::string, samediff::Engine>& k) const;
};
};
@ -87,8 +87,8 @@ namespace sd {
std::vector<sd::ops::DeclarableOp *> _uniqueD;
// pointers to platform-specific helpers
MAP_IMPL<std::pair<Nd4jLong, sd::Engine>, sd::ops::platforms::PlatformHelper*> _helpersLH;
MAP_IMPL<std::pair<std::string, sd::Engine>, sd::ops::platforms::PlatformHelper*> _helpersH;
MAP_IMPL<std::pair<Nd4jLong, samediff::Engine>, sd::ops::platforms::PlatformHelper*> _helpersLH;
MAP_IMPL<std::pair<std::string, samediff::Engine>, sd::ops::platforms::PlatformHelper*> _helpersH;
std::vector<sd::ops::platforms::PlatformHelper*> _uniqueH;
std::mutex _locker;
@ -119,13 +119,13 @@ namespace sd {
void registerHelper(sd::ops::platforms::PlatformHelper* op);
bool hasHelper(Nd4jLong hash, sd::Engine engine);
bool hasHelper(Nd4jLong hash, samediff::Engine engine);
sd::ops::DeclarableOp* getOperation(const char *name);
sd::ops::DeclarableOp* getOperation(Nd4jLong hash);
sd::ops::DeclarableOp* getOperation(std::string &name);
sd::ops::platforms::PlatformHelper* getPlatformHelper(Nd4jLong hash, sd::Engine engine);
sd::ops::platforms::PlatformHelper* getPlatformHelper(Nd4jLong hash, samediff::Engine engine);
std::vector<Nd4jLong> getAllHashes();

View File

@ -37,7 +37,7 @@ namespace sd {
class ND4J_EXPORT PlatformHelper {
protected:
// target engine for this impl
sd::Engine _engine;
samediff::Engine _engine;
// name of the operation this helper is built for
std::string _name;
@ -45,13 +45,13 @@ namespace sd {
// hash of the operation this helper is built for
Nd4jLong _hash;
public:
PlatformHelper(const char *name, sd::Engine engine);
PlatformHelper(const char *name, samediff::Engine engine);
~PlatformHelper() = default;
std::string name();
sd::Engine engine();
samediff::Engine engine();
Nd4jLong hash();

View File

@ -174,7 +174,7 @@ namespace helpers {
}
};
sd::Threads::parallel_tad(func, 0, N);
samediff::Threads::parallel_tad(func, 0, N);
}
void barnes_edge_forces(const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray* output, NDArray const& data) {

View File

@ -154,7 +154,7 @@ void prelu(sd::LaunchContext * context, const NDArray& input, const NDArray& alp
}
};
sd::Threads::parallel_for(func, 0, inputLen);
samediff::Threads::parallel_for(func, 0, inputLen);
}
//////////////////////////////////////////////////////////////////////////

View File

@ -565,7 +565,7 @@ namespace sd {
}
};
//
sd::Threads::parallel_aligned_increment(func, 0, total_num, inc);
samediff::Threads::parallel_aligned_increment(func, 0, total_num, inc);
}
else {
//NC...HW case here
@ -631,7 +631,7 @@ namespace sd {
}
};
//
sd::Threads::parallel_aligned_increment(func, 0, total_num, inc);
samediff::Threads::parallel_aligned_increment(func, 0, total_num, inc);
}
}
//////////////////////////////////////////////////////////////////////////

View File

@ -55,7 +55,7 @@ static void adjustHue_(const NDArray *input, const NDArray* deltaScalarArr, NDAr
}
};
sd::Threads::parallel_for(func, 0, input->lengthOf(), 3);
samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3);
}
else {
@ -87,7 +87,7 @@ static void adjustHue_(const NDArray *input, const NDArray* deltaScalarArr, NDAr
}
};
sd::Threads::parallel_tad(func, 0, numOfTads);
samediff::Threads::parallel_tad(func, 0, numOfTads);
}
}

View File

@ -56,7 +56,7 @@ static void adjustSaturation_(const NDArray *input, const NDArray* factorScalarA
}
};
sd::Threads::parallel_for(func, 0, input->lengthOf(), 3);
samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3);
} else {
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC);
@ -84,7 +84,7 @@ static void adjustSaturation_(const NDArray *input, const NDArray* factorScalarA
}
};
sd::Threads::parallel_tad(func, 0, numOfTads);
samediff::Threads::parallel_tad(func, 0, numOfTads);
}
}

View File

@ -114,7 +114,7 @@ void bgemm_(const std::vector<NDArray*>& vA, const std::vector<NDArray*>& vB, st
}
};
sd::Threads::parallel_tad(func, 0, vaSize);
samediff::Threads::parallel_tad(func, 0, vaSize);
}
}

View File

@ -106,7 +106,7 @@ static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray*
delete []zOffsets;
};
sd::Threads::parallel_do(func, info._numThreads);
samediff::Threads::parallel_do(func, info._numThreads);
}
//////////////////////////////////////////////////////////////////////////
@ -178,7 +178,7 @@ static void batchnorm2_(const NDArray* input, const NDArray* mean, const NDArray
}
};
sd::Threads::parallel_for(func, 0, input->lengthOf());
samediff::Threads::parallel_for(func, 0, input->lengthOf());
}
//////////////////////////////////////////////////////////////////////////

View File

@ -121,7 +121,7 @@ static void betaIncForArray(sd::LaunchContext * context, const NDArray& a, const
output.t<T>(i) = betaIncCore<T>(a.t<T>(i), b.t<T>(i), x.t<T>(i));
};
sd::Threads::parallel_for(func, 0, xLen);
samediff::Threads::parallel_for(func, 0, xLen);
}
///////////////////////////////////////////////////////////////////

View File

@ -89,7 +89,7 @@ void col2im_(sd::LaunchContext & context, const NDArray& input, NDArray& output
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
}
else {
@ -127,7 +127,7 @@ void col2im_(sd::LaunchContext & context, const NDArray& input, NDArray& output
}
};
sd::Threads::parallel_tad(func, 0, bS);
samediff::Threads::parallel_tad(func, 0, bS);
}
}

View File

@ -40,7 +40,7 @@ namespace sd {
}
return sum;
};
sumt = sd::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1);
sumt = samediff::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1);
} else {
//PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:sum)
auto func = PRAGMA_REDUCE_LONG {
@ -53,7 +53,7 @@ namespace sd {
return sum;
};
sumt = sd::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1);
sumt = samediff::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1);
}
//nd4j_printf("Sum: %lld\n", sumt)

View File

@ -40,7 +40,7 @@ namespace helpers {
}
};
sd::Threads::parallel_for(func, 0, lLen);
samediff::Threads::parallel_for(func, 0, lLen);
}
void confusionFunctor(sd::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) {

View File

@ -101,7 +101,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, kD, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, kD, 1);
} else {
@ -139,7 +139,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, oD, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, oD, 1);
//func(0, 0, bS, 1, 0, oD, 1);
}
}
@ -215,7 +215,7 @@ namespace sd {
}
};
sd::Threads::parallel_tad(func, 0, bS);
samediff::Threads::parallel_tad(func, 0, bS);
} else {
@ -251,7 +251,7 @@ namespace sd {
}
};
sd::Threads::parallel_tad(func, 0, bS);
samediff::Threads::parallel_tad(func, 0, bS);
}
}
@ -606,7 +606,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oH, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oH, 1);
}
//////////////////////////////////////////////////////////////////////////
@ -663,7 +663,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
}
//////////////////////////////////////////////////////////////////////////
@ -716,7 +716,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iH, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iH, 1);
}
//////////////////////////////////////////////////////////////////////////
@ -777,7 +777,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iD, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iD, 1);
}
//////////////////////////////////////////////////////////////////////////
@ -860,7 +860,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
}
/*************************************************************************/
else if(poolingMode == 1) { // avg
@ -914,7 +914,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
}
/*************************************************************************/
else if(poolingMode == 2) { // pnorm
@ -963,7 +963,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
}
else {
nd4j_printf("ConvolutionUtils::pooling2d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
@ -1068,7 +1068,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
}
/*************************************************************************/
else if(poolingMode == 1) { // avg
@ -1131,7 +1131,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
}
/*************************************************************************/
else if(poolingMode == 2) { // pnorm
@ -1191,7 +1191,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
}
else {
nd4j_printf("ConvolutionUtils::pooling3d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
@ -1321,7 +1321,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
}
/*************************************************************************/
else if(poolingMode == 1) { // avg
@ -1379,7 +1379,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
}
/*************************************************************************/
else if(poolingMode == 2) { // pnorm
@ -1466,7 +1466,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
}
else {
nd4j_printf("ConvolutionUtils::pooling2dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
@ -1618,7 +1618,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
}
/*************************************************************************/
else if(poolingMode == 1) { // avg
@ -1679,7 +1679,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
}
/*************************************************************************/
else if(poolingMode == 2) { // pnorm
@ -1761,7 +1761,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
}
else {
nd4j_printf("ConvolutionUtils::pooling3dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);

View File

@ -115,7 +115,7 @@ namespace sd {
}
};
sd::Threads::parallel_for(func, 0, cropHeight);
samediff::Threads::parallel_for(func, 0, cropHeight);
}
}
}

View File

@ -48,7 +48,7 @@ void crossBatched(sd::LaunchContext * context, NDArray *a, NDArray *b, NDArray *
}
};
sd::Threads::parallel_tad(func, 0, tads);
samediff::Threads::parallel_tad(func, 0, tads);
}
}

View File

@ -65,7 +65,7 @@ namespace helpers {
}
};
sd::Threads::parallel_for(func, 0, total_count);
samediff::Threads::parallel_for(func, 0, total_count);
} else {
const int total_count = batch_size * input_depth_by_input_area;
@ -89,7 +89,7 @@ namespace helpers {
}
};
sd::Threads::parallel_for(func, 0, total_count);
samediff::Threads::parallel_for(func, 0, total_count);
}
}

View File

@ -35,7 +35,7 @@ static void diGamma_(const NDArray& x, NDArray& z) {
for (auto i = start; i < stop; i++)
z.p(i, diGammaScalar<T>(x.e<T>(i)));
};
sd::Threads::parallel_for(func, 0, x.lengthOf());
samediff::Threads::parallel_for(func, 0, x.lengthOf());
}
void diGamma(sd::LaunchContext* context, const NDArray& x, NDArray& z) {

View File

@ -87,7 +87,7 @@ static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, const
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1);
}
void dilation2d(sd::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {

View File

@ -43,7 +43,7 @@ namespace helpers {
}
};
sd::Threads::parallel_for(func, 0, inLen);
samediff::Threads::parallel_for(func, 0, inLen);
}
BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES);
@ -137,7 +137,7 @@ namespace helpers {
}
};
sd::Threads::parallel_for(func, 0, input->lengthOf());
samediff::Threads::parallel_for(func, 0, input->lengthOf());
return Status::OK();
}

View File

@ -71,7 +71,7 @@ namespace sd {
}
};
sd::Threads::parallel_tad(func, 0, outSize);
samediff::Threads::parallel_tad(func, 0, outSize);
}
}
template <typename T>
@ -177,7 +177,7 @@ namespace sd {
}
};
sd::Threads::parallel_tad(func, 0, gradsSize);
samediff::Threads::parallel_tad(func, 0, gradsSize);
}
outputList[1]->assign(indices);

View File

@ -82,7 +82,7 @@ namespace helpers {
}
};
sd::Threads::parallel_tad(func, 0, batchCount);
samediff::Threads::parallel_tad(func, 0, batchCount);
}

View File

@ -63,7 +63,7 @@ void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* in
output->p(i, input->e(indices->e<Nd4jLong>(i)));
};
sd::Threads::parallel_for(func, 0, output->lengthOf());
samediff::Threads::parallel_for(func, 0, output->lengthOf());
}
else {
@ -96,7 +96,7 @@ void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* in
memcpy(outBuff, inBuff, shape::length(inTadShapeInfo) * input->sizeOfT());
}
};
sd::Threads::parallel_tad(func, 0, numOfSubArrs);
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
}
else {
auto func = PRAGMA_THREADS_FOR {
@ -112,7 +112,7 @@ void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* in
}
};
sd::Threads::parallel_tad(func, 0, numOfSubArrs);
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
}
}
}
@ -148,7 +148,7 @@ void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* in
std::memcpy(outBuff, inBuff, shape::length(inTadShapeInfo) * input->sizeOfT());
}
};
sd::Threads::parallel_tad(func, 0, numOfSubArrs);
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
}
else {
@ -167,7 +167,7 @@ void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* in
}
};
sd::Threads::parallel_tad(func, 0, numOfSubArrs);
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
}
}

View File

@ -64,7 +64,7 @@ namespace sd {
}
};
maxThreads = sd::Threads::parallel_for(func, 0, lengthOf);
maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf);
} else if (xEws > 1 && yEws > 1 && x.ordering() == y.ordering()) {
auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e++) {
@ -75,7 +75,7 @@ namespace sd {
}
};
maxThreads = sd::Threads::parallel_for(func, 0, lengthOf);
maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf);
} else {
auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e++) {
@ -86,7 +86,7 @@ namespace sd {
}
};
maxThreads = sd::Threads::parallel_for(func, 0, lengthOf);
maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf);
}
// accumulate intermediate variables into output array

View File

@ -54,7 +54,7 @@ namespace sd {
tempBuffer[b] = r;
}
};
sd::Threads::parallel_tad(func, 0, numBlocks);
samediff::Threads::parallel_tad(func, 0, numBlocks);
// we replace pointer with intermediate one, and repeat only one chunk left
int iterationCount = 0;
@ -76,7 +76,7 @@ namespace sd {
tempResult[b] = r;
}
};
sd::Threads::parallel_tad(func2, 0, numBlocks);
samediff::Threads::parallel_tad(func2, 0, numBlocks);
iterationCount++;

View File

@ -90,7 +90,7 @@ static void im2col_(sd::LaunchContext & context, const NDArray& input, NDArray&
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
}
else {
@ -124,7 +124,7 @@ static void im2col_(sd::LaunchContext & context, const NDArray& input, NDArray&
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1);
}
}

View File

@ -149,7 +149,7 @@ namespace helpers {
}
}
};
sd::Threads::parallel_tad(func, 0, batchSize);
samediff::Threads::parallel_tad(func, 0, batchSize);
}

View File

@ -178,7 +178,7 @@ namespace helpers {
interpolationData[i]._interpolarValue = in - in_f;
}
};
sd::Threads::parallel_for(func, 0, outSize);
samediff::Threads::parallel_for(func, 0, outSize);
}
/**
@ -240,7 +240,7 @@ namespace helpers {
}
}
};
sd::Threads::parallel_tad(func, 0, batchSize);
samediff::Threads::parallel_tad(func, 0, batchSize);
}
template<typename X, typename Z>
@ -285,7 +285,7 @@ namespace helpers {
xs[i]._topIndex *= channels;
}
};
sd::Threads::parallel_for(func, 0, xsSize);
samediff::Threads::parallel_for(func, 0, xsSize);
resizeImage_<X,Z>(images->getDataBuffer()->primaryAsT<X>(), batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output->dataBuffer()->primaryAsT<Z>());
return Status::OK();
@ -323,7 +323,7 @@ namespace helpers {
}
}
};
sd::Threads::parallel_for(func, 0, batchSize, 1, 0, outHeight, 1);
samediff::Threads::parallel_for(func, 0, batchSize, 1, 0, outHeight, 1);
}
template<typename T>
@ -427,7 +427,7 @@ namespace helpers {
coeffs_table[i * 2 + 1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
}
};
sd::Threads::parallel_for(func, 0, kTableSize);
samediff::Threads::parallel_for(func, 0, kTableSize);
return coeffs_table;
}
@ -541,7 +541,7 @@ namespace helpers {
x_wai._index3);
}
};
sd::Threads::parallel_for(func, 0, resizer_state.outWidth);
samediff::Threads::parallel_for(func, 0, resizer_state.outWidth);
} else {
auto func = PRAGMA_THREADS_FOR {
for (auto x = start; x < stop; ++x) {
@ -552,7 +552,7 @@ namespace helpers {
x_wai._index3);
}
};
sd::Threads::parallel_for(func, 0, resizer_state.outWidth);
samediff::Threads::parallel_for(func, 0, resizer_state.outWidth);
}
// Scale the values so they can be used as offsets into buffers.
auto func = PRAGMA_THREADS_FOR {
@ -563,7 +563,7 @@ namespace helpers {
(*x_wais)[x]._index3 *= resizer_state.channels;
}
};
sd::Threads::parallel_for(func, 0, resizer_state.outWidth);
samediff::Threads::parallel_for(func, 0, resizer_state.outWidth);
}
template <typename T>
@ -774,7 +774,7 @@ namespace helpers {
}
}
};
sd::Threads::parallel_tad(func, 0, batchNum);
samediff::Threads::parallel_tad(func, 0, batchNum);
}
// simplified bicubic resize without antialiasing
@ -950,7 +950,7 @@ namespace helpers {
}
}
};
sd::Threads::parallel_tad(batchProcess, 0, st.batchSize, 1);
samediff::Threads::parallel_tad(batchProcess, 0, st.batchSize, 1);
}
template <typename X>
@ -981,7 +981,7 @@ namespace helpers {
}
};
sd::Threads::parallel_for(cachingProcedure, 0, xCached.size(), 1);
samediff::Threads::parallel_for(cachingProcedure, 0, xCached.size(), 1);
resizeArea<X>(st, xCached, image, output);
}

View File

@ -45,7 +45,7 @@ static void rgbToGrs_(const NDArray& input, NDArray& output, const int dimC) {
}
};
sd::Threads::parallel_for(func, 0, output.lengthOf(), 1);
samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
return;
}
@ -62,7 +62,7 @@ static void rgbToGrs_(const NDArray& input, NDArray& output, const int dimC) {
}
};
sd::Threads::parallel_for(func, 0, output.lengthOf(), 1);
samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
return;
}
@ -87,7 +87,7 @@ FORCEINLINE static void rgbToFromYuv_(const NDArray& input, NDArray& output, con
}
};
sd::Threads::parallel_for(func, 0, input.lengthOf(), 3);
samediff::Threads::parallel_for(func, 0, input.lengthOf(), 3);
return;
}
@ -106,7 +106,7 @@ FORCEINLINE static void rgbToFromYuv_(const NDArray& input, NDArray& output, con
}
};
sd::Threads::parallel_tad(func, 0, numOfTads);
samediff::Threads::parallel_tad(func, 0, numOfTads);
return;
}
@ -146,7 +146,7 @@ FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output,
}
};
sd::Threads::parallel_for(func, 0, input->lengthOf(), 3);
samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3);
}
else {
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
@ -165,7 +165,7 @@ FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output,
}
};
sd::Threads::parallel_tad(func, 0, numOfTads);
samediff::Threads::parallel_tad(func, 0, numOfTads);
}
}
@ -196,7 +196,7 @@ FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output,
}
};
sd::Threads::parallel_for(func, 0, input->lengthOf(), 3);
samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3);
}
else {
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
@ -222,7 +222,7 @@ FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output,
}
};
sd::Threads::parallel_tad(func, 0, numOfTads);
samediff::Threads::parallel_tad(func, 0, numOfTads);
}
}

View File

@ -195,7 +195,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
}
};
sd::Threads::parallel_tad(func, 0, tads);
samediff::Threads::parallel_tad(func, 0, tads);
}
}

View File

@ -96,7 +96,7 @@ static int lrnFunctor_(sd::graph::Context& block, NDArray* input, NDArray* outpu
}
};
sd::Threads::parallel_tad(func, 0, numOfTads);
samediff::Threads::parallel_tad(func, 0, numOfTads);
}
else {
auto func = PRAGMA_THREADS_FOR {
@ -134,7 +134,7 @@ static int lrnFunctor_(sd::graph::Context& block, NDArray* input, NDArray* outpu
}
};
sd::Threads::parallel_tad(func, 0, numOfTads);
samediff::Threads::parallel_tad(func, 0, numOfTads);
}
return Status::OK();
}
@ -242,7 +242,7 @@ static void lrnBP_(const NDArray& input, const NDArray& gradO, NDArray& gradI, c
}
};
sd::Threads::parallel_tad(func, 0, numOfTads);
samediff::Threads::parallel_tad(func, 0, numOfTads);
}
else {
@ -317,7 +317,7 @@ static void lrnBP_(const NDArray& input, const NDArray& gradO, NDArray& gradI, c
}
};
sd::Threads::parallel_tad(func, 0, numOfTads);
samediff::Threads::parallel_tad(func, 0, numOfTads);
}
gradI *= gradO;
}

View File

@ -130,7 +130,7 @@ static void fusedTanh(NDArray *z, NDArray *i, NDArray *c, const NDArray *cLast,
}
};
sd::Threads::parallel_for(func, 0, uLen);
samediff::Threads::parallel_for(func, 0, uLen);
}
//////////////////////////////////////////////////////////////////////////

View File

@ -54,7 +54,7 @@ namespace helpers {
}
};
sd::Threads::parallel_tad(loop, 0, n, 1);
samediff::Threads::parallel_tad(loop, 0, n, 1);
}
}
@ -79,8 +79,8 @@ namespace helpers {
invertedMatrix->t<T>(i, i - 1) -= (inputMatrix->t<T>(i, i - 1) * invertedMatrix->t<T>(i - 1, i - 1) / inputMatrix->t<T>(i, i));
};
sd::Threads::parallel_for(invertDiagonals, 0, n, 1);
sd::Threads::parallel_for(invertSubDiagonals, 1, n, 1);
samediff::Threads::parallel_for(invertDiagonals, 0, n, 1);
samediff::Threads::parallel_for(invertSubDiagonals, 1, n, 1);
// PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int i = 1; i < n; i++) {
@ -118,8 +118,8 @@ namespace helpers {
inputMatrix->t<T>(i, i));
};
sd::Threads::parallel_for(invertDiagonals, 0, n, 1);
sd::Threads::parallel_for(invertUpDiagonals, 0, n - 1, 1);
samediff::Threads::parallel_for(invertDiagonals, 0, n, 1);
samediff::Threads::parallel_for(invertUpDiagonals, 0, n - 1, 1);
// PRAGMA_OMP_PARALLEL_FOR_SIMD
for (auto i = n - 2; i >= 0; i--) {
@ -225,7 +225,7 @@ namespace helpers {
}
}
//};
//sd::Threads::parallel_for(loop, column, rowNum, 1);
//samediff::Threads::parallel_for(loop, column, rowNum, 1);
return result;
}
@ -247,7 +247,7 @@ namespace helpers {
}
}
};
sd::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1);
samediff::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1);
}
template <typename T>
@ -327,7 +327,7 @@ namespace helpers {
luNN_<T, I>(context, outputs.at(i), permutationVectors?permutations.at(i):nullptr, n);
}
};
sd::Threads::parallel_for(loop, 0, outputs.size(), 1);
samediff::Threads::parallel_for(loop, 0, outputs.size(), 1);
}
void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) {

View File

@ -63,7 +63,7 @@ void matrixSetDiag_(const NDArray& input, const NDArray& diagonal, NDArray& outp
z[zOffset] = zeroPad ? static_cast<T>(0) : x[xOffset];
}
};
sd::Threads::parallel_for(func, 0, xLen);
samediff::Threads::parallel_for(func, 0, xLen);
}
//////////////////////////////////////////////////////////////////////////

View File

@ -51,7 +51,7 @@ int _matrixDiagPart(const NDArray* input, NDArray* output) {
listOut.at(i)->p(j, listDiag.at(i)->e<T>(j, j));
};
sd::Threads::parallel_tad(func, 0, lO);
samediff::Threads::parallel_tad(func, 0, lO);
return Status::OK();
}

View File

@ -61,7 +61,7 @@ namespace helpers {
}
};
sd::Threads::parallel_for(func, 0, oL);
samediff::Threads::parallel_for(func, 0, oL);
}
}

View File

@ -67,7 +67,7 @@ namespace sd {
}
};
sd::Threads::parallel_tad(func, 0, numTads);
samediff::Threads::parallel_tad(func, 0, numTads);
} else {
auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e++) {
@ -88,7 +88,7 @@ namespace sd {
}
};
sd::Threads::parallel_tad(func, 0, numTads);
samediff::Threads::parallel_tad(func, 0, numTads);
}
}

View File

@ -80,7 +80,7 @@ static void polyGamma_(sd::LaunchContext * context, const NDArray& n, const NDAr
output.p(i, polyGammaScalar<T>(context, order, x.e<T>(i)));
}
};
sd::Threads::parallel_for(func, 0, x.lengthOf());
samediff::Threads::parallel_for(func, 0, x.lengthOf());
}
void polyGamma(sd::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) {

View File

@ -48,7 +48,7 @@ namespace helpers {
resBuf[i * n + j] = -2 * vBuf[i] * vBuf[j] + (i == j ? T(1) : T(0));
};
sd::Threads::parallel_for(interloop, 0, n, 1, 0, n, 1);
samediff::Threads::parallel_for(interloop, 0, n, 1, 0, n, 1);
return res;
}
@ -119,7 +119,7 @@ namespace helpers {
}
};
sd::Threads::parallel_tad(batching, 0, listOutQ.size(), 1);
samediff::Threads::parallel_tad(batching, 0, listOutQ.size(), 1);
}

View File

@ -197,7 +197,7 @@ namespace helpers {
}
};
sd::Threads::parallel_for(func, 0, batchValue, 1, 0, numOfSamples, 1);
samediff::Threads::parallel_for(func, 0, batchValue, 1, 0, numOfSamples, 1);
rng.rewindH(output.lengthOf()*numOfClassX);
return;

View File

@ -42,7 +42,7 @@ static void _range(const NDArray& start, const NDArray& delta, NDArray& outVecto
for (auto i = start; i < stop; i++)
buff[i] = s + i * d;
};
sd::Threads::parallel_for(func, 0, len);
samediff::Threads::parallel_for(func, 0, len);
}
void range(sd::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector) {

View File

@ -59,7 +59,7 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
swap(inArr, e, idx);
}
};
sd::Threads::parallel_for(func, 0, numOfElemsToReverse / 2);
samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2);
}
else if (inEWS > 1) {
auto func = PRAGMA_THREADS_FOR {
@ -70,7 +70,7 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
}
};
sd::Threads::parallel_for(func, 0, numOfElemsToReverse / 2);
samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2);
}
else {
@ -82,7 +82,7 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
}
};
sd::Threads::parallel_for(func, 0, numOfElemsToReverse / 2);
samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2);
}
}
else {
@ -96,14 +96,14 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
for (Nd4jLong e = start; e < stop; e++)
outArr[sLength - e] = inArr[e];
};
sd::Threads::parallel_for(func, 0, numOfElemsToReverse);
samediff::Threads::parallel_for(func, 0, numOfElemsToReverse);
if(inLength != numOfElemsToReverse) {
auto f2 = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e++)
outArr[e] = inArr[e];
};
sd::Threads::parallel_for(f2, numOfElemsToReverse, inLength);
samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength);
}
}
else if (inEWS >= 1 && outEWS >= 1 && inOrder == outOrder) {
@ -112,14 +112,14 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
for (auto e = start; e < stop; e++)
outArr[(sLength - e) * outEWS] = inArr[e * inEWS];
};
sd::Threads::parallel_for(func, 0, numOfElemsToReverse);
samediff::Threads::parallel_for(func, 0, numOfElemsToReverse);
if(inLength != numOfElemsToReverse) {
auto f2 = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e++)
outArr[e * outEWS] = inArr[e * inEWS];
};
sd::Threads::parallel_for(f2, numOfElemsToReverse, inLength);
samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength);
}
}
else {
@ -131,7 +131,7 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
outArr[outOffset] = inArr[inOffset];
}
};
sd::Threads::parallel_for(func, 0, numOfElemsToReverse);
samediff::Threads::parallel_for(func, 0, numOfElemsToReverse);
if(inLength != numOfElemsToReverse) {
@ -142,7 +142,7 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
outArr[outOffset] = inArr[inOffset];
}
};
sd::Threads::parallel_for(f2, numOfElemsToReverse, inLength);
samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength);
}
}
}

View File

@ -69,7 +69,7 @@ static void batchToSpace_(const NDArray& input, NDArray& output, const uint crop
}
};
sd::Threads::parallel_for(func, 0, bS, 1, cropBottom, iH - cropTop, 1, cropLeft, iW - cropRight, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, cropBottom, iH - cropTop, 1, cropLeft, iW - cropRight, 1);
}
BUILD_SINGLE_TEMPLATE(template void batchToSpace_, (const NDArray& input, NDArray& output, const uint cropBottom, const uint cropTop, const uint cropLeft, const uint cropRight), LIBND4J_TYPES);
@ -128,7 +128,7 @@ static void batchToSpaceND_(const NDArray& input, const NDArray& crop, NDArray&
}
};
sd::Threads::parallel_tad(func, 0, zLen);
samediff::Threads::parallel_tad(func, 0, zLen);
}
BUILD_SINGLE_TEMPLATE(template void batchToSpaceND_, (const NDArray& input, const NDArray& crop, NDArray& output, const uint numOfSpatialDims), LIBND4J_TYPES);
@ -234,7 +234,7 @@ static void spaceToBatch_(const NDArray& input, NDArray& output, const uint padB
}
};
sd::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1);
samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1);
}
BUILD_SINGLE_TEMPLATE(template void spaceToBatch_, (const NDArray& input, NDArray& output, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight), LIBND4J_TYPES);
@ -327,7 +327,7 @@ static void spaceToBatchND_(const NDArray& input, const NDArray& padding, NDArra
}
};
sd::Threads::parallel_tad(func, 0, zLen);
samediff::Threads::parallel_tad(func, 0, zLen);
}
BUILD_SINGLE_TEMPLATE(template void spaceToBatchND_, (const NDArray& input, const NDArray& padding, NDArray& output, const uint numOfSpatialDims), LIBND4J_TYPES);

View File

@ -69,7 +69,7 @@ namespace helpers {
}
};
sd::Threads::parallel_for(func, 0, total_count);
samediff::Threads::parallel_for(func, 0, total_count);
} else {
const int total_count = batch_size * output_depth_by_output_area;
@ -93,7 +93,7 @@ namespace helpers {
}
};
sd::Threads::parallel_for(func, 0, total_count);
samediff::Threads::parallel_for(func, 0, total_count);
}
}

View File

@ -58,7 +58,7 @@ Nd4jLong checkIndices_(const NDArray& indices, const NDArray& output, const int
}
};
sd::Threads::parallel_for(func, 0, indices.lengthOf());
samediff::Threads::parallel_for(func, 0, indices.lengthOf());
return numOfBadIndx;
}
@ -87,7 +87,7 @@ void scatter(sd::LaunchContext *context, pairwise::Ops op, const NDArray& indic
}
};
sd::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads());
samediff::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads());
}
else { // outRank > 1
@ -107,7 +107,7 @@ void scatter(sd::LaunchContext *context, pairwise::Ops op, const NDArray& indic
}
};
sd::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads());
samediff::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads());
}
}
@ -129,7 +129,7 @@ void scatterND(sd::LaunchContext *context, pairwise::Ops op, const NDArray& ind
}
};
sd::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads());
samediff::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads());
}
else {
std::vector<int> dimsToExcludeInd = ShapeUtils::evalDimsToExclude(indRank, {indRank-1});
@ -154,7 +154,7 @@ void scatterND(sd::LaunchContext *context, pairwise::Ops op, const NDArray& ind
}
};
sd::Threads::parallel_tad(func, 0, indLen / indLastDim, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads());
samediff::Threads::parallel_tad(func, 0, indLen / indLastDim, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads());
}
}
@ -176,7 +176,7 @@ void scatterForLoss(sd::LaunchContext *context, const NDArray& indices, NDArray
}
};
sd::Threads::parallel_for(func, 0, indicesLen);
samediff::Threads::parallel_for(func, 0, indicesLen);
} else {
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i++) {
@ -186,7 +186,7 @@ void scatterForLoss(sd::LaunchContext *context, const NDArray& indices, NDArray
}
};
sd::Threads::parallel_for(func, 0, indicesLen);
samediff::Threads::parallel_for(func, 0, indicesLen);
}
}

View File

@ -173,7 +173,7 @@ namespace helpers {
meanV.p<T>(e, meanV.e<T>(e) + listOfTensors.at(i)->e<T>(e));
}
};
sd::Threads::parallel_for(func, 0, meanT->lengthOf());
samediff::Threads::parallel_for(func, 0, meanT->lengthOf());
count++;
}
@ -227,7 +227,7 @@ namespace helpers {
sumT->p(e, sumT->e<T>(e) + listOfTensors.at(i)->e<T>(e));
}
};
sd::Threads::parallel_for(func, 0, sumT->lengthOf());
samediff::Threads::parallel_for(func, 0, sumT->lengthOf());
}
else {
idx = indices->e<int>(i);
@ -276,7 +276,7 @@ namespace helpers {
sumT->p(e, sumT->e<T>(e) * listOfTensors.at(i)->e<T>(e));
}
};
sd::Threads::parallel_for(func, 0, sumT->lengthOf());
samediff::Threads::parallel_for(func, 0, sumT->lengthOf());
}
else {
idx = indices->e<int>(i);
@ -631,7 +631,7 @@ namespace helpers {
output->p(e, gradOut->e<T>(classNum));
}
};
sd::Threads::parallel_for(func, 0, loop_size);
samediff::Threads::parallel_for(func, 0, loop_size);
}
else {
std::vector<int> restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
@ -658,7 +658,7 @@ namespace helpers {
}
};
sd::Threads::parallel_tad(func, 0, indices->lengthOf());
samediff::Threads::parallel_tad(func, 0, indices->lengthOf());
}
return ND4J_STATUS_OK;
@ -681,7 +681,7 @@ namespace helpers {
output->p(e, gradOut->e<double>(classNum));
}
};
sd::Threads::parallel_for(func, 0, input->lengthOf());
samediff::Threads::parallel_for(func, 0, input->lengthOf());
}
else {
auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
@ -711,7 +711,7 @@ namespace helpers {
}
};
sd::Threads::parallel_tad(func, 0, indices->lengthOf());
samediff::Threads::parallel_tad(func, 0, indices->lengthOf());
}
return ND4J_STATUS_OK;
}
@ -758,7 +758,7 @@ namespace helpers {
}
//};
//sd::Threads::parallel_for(func, 0, indices->lengthOf());
//samediff::Threads::parallel_for(func, 0, indices->lengthOf());
}
return ND4J_STATUS_OK;
}
@ -791,7 +791,7 @@ namespace helpers {
}
//};
//sd::Threads::parallel_for(func, 0, indices->lengthOf());
//samediff::Threads::parallel_for(func, 0, indices->lengthOf());
}
return Status::OK();
}
@ -828,7 +828,7 @@ namespace helpers {
}
//};
//sd::Threads::parallel_for(func, 0, indices->lengthOf());
//samediff::Threads::parallel_for(func, 0, indices->lengthOf());
}
return ND4J_STATUS_OK;
@ -894,7 +894,7 @@ namespace helpers {
}
};
sd::Threads::parallel_for(func, 0, input->lengthOf());
samediff::Threads::parallel_for(func, 0, input->lengthOf());
}
else {
auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
@ -918,7 +918,7 @@ namespace helpers {
}
//};
//sd::Threads::parallel_for(func, 0, indices->lengthOf());
//samediff::Threads::parallel_for(func, 0, indices->lengthOf());
}
return ND4J_STATUS_OK;
@ -993,7 +993,7 @@ namespace helpers {
}
//};
//sd::Threads::parallel_for(func, 0, indices->lengthOf());
//samediff::Threads::parallel_for(func, 0, indices->lengthOf());
}
return Status::OK();
}
@ -1010,7 +1010,7 @@ namespace helpers {
}
};
sd::Threads::parallel_for(func, 0, indices->lengthOf());
samediff::Threads::parallel_for(func, 0, indices->lengthOf());
}
else {
auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
@ -1032,7 +1032,7 @@ namespace helpers {
}
//};
//sd::Threads::parallel_for(func, 0, indices->lengthOf());
//samediff::Threads::parallel_for(func, 0, indices->lengthOf());
}
return Status::OK();
@ -1059,7 +1059,7 @@ namespace helpers {
}
//};
//sd::Threads::parallel_for(func, 0, indices->lengthOf());
//samediff::Threads::parallel_for(func, 0, indices->lengthOf());
}
else {
auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
@ -1081,7 +1081,7 @@ namespace helpers {
}
//};
//sd::Threads::parallel_for(func, 0, indices->lengthOf());
//samediff::Threads::parallel_for(func, 0, indices->lengthOf());
}
return Status::OK();
}

View File

@ -34,7 +34,7 @@ namespace helpers {
output->t<B>(k * maxIndex + i) = B(true); //, T(1.0f));
};
sd::Threads::parallel_for(func, 0, maxIndex, 1, 0, input->lengthOf(), 1);
samediff::Threads::parallel_for(func, 0, maxIndex, 1, 0, input->lengthOf(), 1);
}
void sequenceMask(sd::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {

View File

@ -425,7 +425,7 @@ namespace sd {
}
};
sd::Threads::parallel_tad(func, 0, numTargets, 1, numThreads);
samediff::Threads::parallel_tad(func, 0, numTargets, 1, numThreads);
}
BUILD_SINGLE_TEMPLATE(template void skipgramBatchExec_, (NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool preciseMode, const int numThreads), FLOAT_TYPES);
@ -577,7 +577,7 @@ namespace sd {
}
};
sd::Threads::parallel_tad(func, 0, numTargets, 1, numThreads);
samediff::Threads::parallel_tad(func, 0, numTargets, 1, numThreads);
}
BUILD_SINGLE_TEMPLATE(template void cbowBatchExec_, (NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &context, NDArray &lockedWords, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, NDArray &nLabels, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool trainWords, const int numThreads), FLOAT_TYPES);

View File

@ -136,7 +136,7 @@ namespace sd {
}
};
sd::Threads::parallel_tad(func,0, numOfSubArrs);
samediff::Threads::parallel_tad(func,0, numOfSubArrs);
}
#endif
@ -168,7 +168,7 @@ namespace sd {
}
};
sd::Threads::parallel_tad(func,0, numOfSubArrs);
samediff::Threads::parallel_tad(func,0, numOfSubArrs);
}
//////////////////////////////////////////////////////////////////////////
@ -228,7 +228,7 @@ namespace sd {
}
};
sd::Threads::parallel_tad(func, 0, numOfSubArrs);
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
delete []offsets;
}

View File

@ -48,7 +48,7 @@ namespace helpers {
}
}
};
sd::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1);
samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1);
}
// --------------------------------------------------------------------------------------------------------------------------------------- //

View File

@ -115,7 +115,7 @@ namespace helpers {
}
};
sd::Threads::parallel_for(func, 0, input.lengthOf());
samediff::Threads::parallel_for(func, 0, input.lengthOf());
}
void split(sd::LaunchContext* context, const NDArray& input, std::vector<NDArray*>& outArrs, const int axis) {

View File

@ -184,7 +184,7 @@ static void sruBI_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray
}
};
sd::Threads::parallel_tad(func, 0, ncols);
samediff::Threads::parallel_tad(func, 0, ncols);
}
//////////////////////////////////////////////////////////////////////////
@ -303,7 +303,7 @@ static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArr
}
};
sd::Threads::parallel_tad(func, 0, ncols);
samediff::Threads::parallel_tad(func, 0, ncols);
// gradB
gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0}); // [4*K]

View File

@ -43,7 +43,7 @@ static void stack_(const std::vector<const NDArray*>& inArrs, NDArray& output, c
output.p<T>(i, inArrs[i]->t<T>(0));
};
sd::Threads::parallel_for(func, 0, numOfSubArrs);
samediff::Threads::parallel_for(func, 0, numOfSubArrs);
}
else {
@ -63,7 +63,7 @@ static void stack_(const std::vector<const NDArray*>& inArrs, NDArray& output, c
}
};
sd::Threads::parallel_tad(func, 0, numOfSubArrs);
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
}
}
@ -88,7 +88,7 @@ static void unstack_(const NDArray& input, const std::vector<NDArray*>& outArrs,
outArrs[i]->p<T>(0, input.t<T>(i));
};
sd::Threads::parallel_for(func, 0, numOfSubArrs);
samediff::Threads::parallel_for(func, 0, numOfSubArrs);
}
else {
@ -107,7 +107,7 @@ static void unstack_(const NDArray& input, const std::vector<NDArray*>& outArrs,
}
};
sd::Threads::parallel_tad(func, 0, numOfSubArrs);
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
}
}

View File

@ -163,7 +163,7 @@ namespace helpers {
}
};
sd::Threads::parallel_tad(func, 0, target->lengthOf());
samediff::Threads::parallel_tad(func, 0, target->lengthOf());
}
return status;

View File

@ -48,7 +48,7 @@ static void triuBP_(sd::LaunchContext * context, const NDArray& input, const NDA
dOdI.t<T>(i) = static_cast<T>(1.f);
}
};
sd::Threads::parallel_for(func, 0, dLen);
samediff::Threads::parallel_for(func, 0, dLen);
// FIXME: !!!
gradI.assign(dOdI * gradO); // chain rule: dLoss/dI = dO/dI * dLoss/dO
@ -68,7 +68,7 @@ static void trace_(const NDArray& input, NDArray& output) {
for (auto i = start; i < stop; i++)
output.p(i, setOfSubArrs.at(i)->getTrace());
};
sd::Threads::parallel_for(func, 0, setOfSubArrs.size());
samediff::Threads::parallel_for(func, 0, setOfSubArrs.size());
}
void trace(sd::LaunchContext * context, const NDArray& input, NDArray& output) {
@ -211,7 +211,7 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray
}
};
sd::Threads::parallel_tad(func, 0, zLen);
samediff::Threads::parallel_tad(func, 0, zLen);
}
else { // REFLECT and SYMMETRIC cases
@ -237,7 +237,7 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray
}
};
sd::Threads::parallel_tad(func, 0, zLen);
samediff::Threads::parallel_tad(func, 0, zLen);
}
}
@ -606,7 +606,7 @@ static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) {
}
};
sd::Threads::parallel_tad(func, 0, zLen);
samediff::Threads::parallel_tad(func, 0, zLen);
}
////////////////////////////////////////////////////////////////////////
@ -654,7 +654,7 @@ static void gather_(NDArray* input, const NDArray* indices, NDArray* output, con
output->p(e, input->e<T>(indices->e<Nd4jLong>(e)));
};
sd::Threads::parallel_for(func, 0, indices->lengthOf());
samediff::Threads::parallel_for(func, 0, indices->lengthOf());
}
else {
@ -670,7 +670,7 @@ static void gather_(NDArray* input, const NDArray* indices, NDArray* output, con
}
};
sd::Threads::parallel_tad(func, 0, numOfSubArrs);
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
}
}
else {
@ -694,7 +694,7 @@ static void gather_(NDArray* input, const NDArray* indices, NDArray* output, con
}
};
sd::Threads::parallel_tad(func, 0, numOfSubArrs);
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
}
}
}
@ -714,7 +714,7 @@ void eye(sd::LaunchContext * context, NDArray& output) {
arrs.at(i)->setIdentity();
};
sd::Threads::parallel_tad(func, 0, arrs.size());
samediff::Threads::parallel_tad(func, 0, arrs.size());
}
//////////////////////////////////////////////////////////////////////////
@ -772,7 +772,7 @@ void scatterUpdate(sd::LaunchContext * context, NDArray& input, NDArray& updates
}
};
sd::Threads::parallel_tad(func, 0, indices.size());
samediff::Threads::parallel_tad(func, 0, indices.size());
}
@ -792,7 +792,7 @@ void scatterSimple(sd::LaunchContext * context, const int opId, NDArray& input,
}
};
sd::Threads::parallel_for(func, 0, len);
samediff::Threads::parallel_for(func, 0, len);
}
break;
@ -824,7 +824,7 @@ static void mergeMaxIndex_(const std::vector<NDArray*>& inArrs, NDArray& output)
}
};
sd::Threads::parallel_for(func, 0, x->lengthOf());
samediff::Threads::parallel_for(func, 0, x->lengthOf());
}
void mergeMaxIndex(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
@ -850,7 +850,7 @@ static void mergeMax_(const std::vector<NDArray*>& inArrs, NDArray& output) {
}
};
sd::Threads::parallel_for(func, 0, x->lengthOf());
samediff::Threads::parallel_for(func, 0, x->lengthOf());
}
void mergeMax(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
@ -875,7 +875,7 @@ static void mergeAvg_(const std::vector<NDArray*>& inArrs, NDArray& output) {
}
};
sd::Threads::parallel_for(func, 0, x->lengthOf());
samediff::Threads::parallel_for(func, 0, x->lengthOf());
}
void mergeAvg(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
@ -900,7 +900,7 @@ static void mergeAdd_(const std::vector<NDArray*>& inArrs, NDArray& output) {
}
};
sd::Threads::parallel_for(func, 0, x->lengthOf());
samediff::Threads::parallel_for(func, 0, x->lengthOf());
}
void mergeAdd(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (inArrs, output), LIBND4J_TYPES);
@ -934,7 +934,7 @@ static void clipByNorm_(NDArray& input, NDArray& output, const std::vector<int>&
*listOfInSubArrs.at(i) *= normClip / iNormActual;
}
};
sd::Threads::parallel_tad(func, 0, listOfInSubArrs.size());
samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size());
}
}
else {
@ -963,7 +963,7 @@ static void clipByNorm_(NDArray& input, NDArray& output, const std::vector<int>&
*outputSubArr *= clipNorm / iNormActual;
}
};
sd::Threads::parallel_tad(func, 0, listOfInSubArrs.size());
samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size());
}
}
}
@ -1079,7 +1079,7 @@ static void clipByNormBP_(const NDArray& input, const NDArray& gradO, NDArray& g
gradISubArr->assign(gradOSubArr);
}
};
sd::Threads::parallel_tad(func, 0, gradISubArrs.size());
samediff::Threads::parallel_tad(func, 0, gradISubArrs.size());
}
}
@ -1215,7 +1215,7 @@ static void mirrorPad_(const NDArray& input, const NDArray& paddings, NDArray& o
}
};
sd::Threads::parallel_for(func, 0, outLen);
samediff::Threads::parallel_for(func, 0, outLen);
}
}

View File

@ -99,7 +99,7 @@ namespace helpers {
}
};
sd::Threads::parallel_tad(batchLoop, 0, leftPart.size(), 1);
samediff::Threads::parallel_tad(batchLoop, 0, leftPart.size(), 1);
return Status::OK();
@ -128,7 +128,7 @@ namespace helpers {
}
}
};
sd::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1);
samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1);
}
int triangularSolveFunctor(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output) {

View File

@ -68,7 +68,7 @@ static void zeta_(sd::LaunchContext * context, const NDArray& x, const NDArray&
z.p(i, zetaScalar<T>(x.e<T>(i), q.e<T>(i)));
};
sd::Threads::parallel_for(func, 0, xLen);
samediff::Threads::parallel_for(func, 0, xLen);
}
void zeta(sd::LaunchContext * context, const NDArray& x, const NDArray& q, NDArray& z) {

View File

@ -77,7 +77,7 @@ void FORCEINLINE cross(sd::LaunchContext * context, NDArray *a, NDArray *b, NDAr
}
};
sd::Threads::parallel_tad(func, 0, tads);
samediff::Threads::parallel_tad(func, 0, tads);
}
void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext * context, NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output);

Some files were not shown because too many files have changed in this diff Show More