OpenMP Threads execution (#297)
* omp threads backported
  Signed-off-by: raver119 <raver119@gmail.com>
* omp scalar reduce
  Signed-off-by: raver119 <raver119@gmail.com>
* timing
  Signed-off-by: raver119 <raver119@gmail.com>
* timing
  Signed-off-by: raver119 <raver119@gmail.com>
* minor tweaks
  Signed-off-by: raver119 <raver119@gmail.com>
* minor tweaks
  Signed-off-by: raver119 <raver119@gmail.com>
* namespace change
  Signed-off-by: raver119 <raver119@gmail.com>
* num_threads
  Signed-off-by: raver119 <raver119@gmail.com>
* one minor fix
  Signed-off-by: raver119 <raver119@gmail.com>
master
parent
a2ec3dbc97
commit
dd2043ef48
|
@ -21,9 +21,9 @@ if (SD_CUDA)
|
||||||
enable_language(CUDA)
|
enable_language(CUDA)
|
||||||
set(CMAKE_CUDA_STANDARD 11)
|
set(CMAKE_CUDA_STANDARD 11)
|
||||||
|
|
||||||
set(DEFAULT_ENGINE "samediff::ENGINE_CUDA")
|
set(DEFAULT_ENGINE "sd::ENGINE_CUDA")
|
||||||
else()
|
else()
|
||||||
set(DEFAULT_ENGINE "samediff::ENGINE_CPU")
|
set(DEFAULT_ENGINE "sd::ENGINE_CPU")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# MSVC runtime lib can be either "MultiThreaded" or "MultiThreadedDLL", /MT and /MD respectively
|
# MSVC runtime lib can be either "MultiThreaded" or "MultiThreadedDLL", /MT and /MD respectively
|
||||||
|
|
|
@ -56,7 +56,7 @@ namespace sd {
|
||||||
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
|
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, length);
|
sd::Threads::parallel_for(func, 0, length);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
delete[] tmp;
|
delete[] tmp;
|
||||||
|
@ -114,7 +114,7 @@ namespace sd {
|
||||||
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
|
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, length);
|
sd::Threads::parallel_for(func, 0, length);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
delete[] tmp;
|
delete[] tmp;
|
||||||
|
@ -142,7 +142,7 @@ namespace sd {
|
||||||
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
|
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, length);
|
sd::Threads::parallel_for(func, 0, length);
|
||||||
#endif
|
#endif
|
||||||
delete[] tmp;
|
delete[] tmp;
|
||||||
}
|
}
|
||||||
|
@ -168,7 +168,7 @@ namespace sd {
|
||||||
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
|
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, length);
|
sd::Threads::parallel_for(func, 0, length);
|
||||||
#endif
|
#endif
|
||||||
delete[] tmp;
|
delete[] tmp;
|
||||||
}
|
}
|
||||||
|
|
|
@ -515,7 +515,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
|
sd::Threads::parallel_for(func, 0, lengthOf(), 1);
|
||||||
|
|
||||||
tickWriteHost();
|
tickWriteHost();
|
||||||
syncToDevice();
|
syncToDevice();
|
||||||
|
@ -582,7 +582,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::stri
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
|
sd::Threads::parallel_for(func, 0, lengthOf(), 1);
|
||||||
|
|
||||||
|
|
||||||
tickWriteHost();
|
tickWriteHost();
|
||||||
|
@ -648,7 +648,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u16s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
|
sd::Threads::parallel_for(func, 0, lengthOf(), 1);
|
||||||
|
|
||||||
tickWriteHost();
|
tickWriteHost();
|
||||||
syncToDevice();
|
syncToDevice();
|
||||||
|
@ -714,7 +714,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
|
sd::Threads::parallel_for(func, 0, lengthOf(), 1);
|
||||||
|
|
||||||
tickWriteHost();
|
tickWriteHost();
|
||||||
syncToDevice();
|
syncToDevice();
|
||||||
|
@ -780,7 +780,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
|
sd::Threads::parallel_for(func, 0, lengthOf(), 1);
|
||||||
|
|
||||||
tickWriteHost();
|
tickWriteHost();
|
||||||
syncToDevice();
|
syncToDevice();
|
||||||
|
@ -846,7 +846,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
|
sd::Threads::parallel_for(func, 0, lengthOf(), 1);
|
||||||
|
|
||||||
tickWriteHost();
|
tickWriteHost();
|
||||||
syncToDevice();
|
syncToDevice();
|
||||||
|
@ -2393,7 +2393,7 @@ NDArray NDArray::asS() const {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
|
sd::Threads::parallel_for(func, 0, lengthOf(), 1);
|
||||||
|
|
||||||
registerPrimaryUse({ &res }, { this });
|
registerPrimaryUse({ &res }, { this });
|
||||||
|
|
||||||
|
@ -3466,7 +3466,7 @@ NDArray NDArray::dup(const char newOrder) const {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
|
sd::Threads::parallel_for(func, 0, lengthOf(), 1);
|
||||||
|
|
||||||
return NDArray(getShapeAsVector(), strings, dataType(), getContext());
|
return NDArray(getShapeAsVector(), strings, dataType(), getContext());
|
||||||
}
|
}
|
||||||
|
@ -3479,7 +3479,7 @@ NDArray NDArray::dup(const char newOrder) const {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
|
sd::Threads::parallel_for(func, 0, lengthOf(), 1);
|
||||||
|
|
||||||
return NDArray(getShapeAsVector(), strings, dataType(), getContext());
|
return NDArray(getShapeAsVector(), strings, dataType(), getContext());
|
||||||
}
|
}
|
||||||
|
@ -3491,7 +3491,7 @@ NDArray NDArray::dup(const char newOrder) const {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
|
sd::Threads::parallel_for(func, 0, lengthOf(), 1);
|
||||||
|
|
||||||
return NDArray(getShapeAsVector(), strings, dataType(), getContext());
|
return NDArray(getShapeAsVector(), strings, dataType(), getContext());
|
||||||
}
|
}
|
||||||
|
|
|
@ -115,7 +115,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& t
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, zLen);
|
sd::Threads::parallel_for(func, 0, zLen);
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, (const float val, int lower, int upper, NDArray& target, const char direction), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, (const float val, int lower, int upper, NDArray& target, const char direction), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
@ -159,7 +159,7 @@ static void templatedSwap(void *xBuffer, void *yBuffer, Nd4jLong length) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, length);
|
sd::Threads::parallel_for(func, 0, length);
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template void templatedSwap, (void *xBuffer, void *yBuffer, Nd4jLong length), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void templatedSwap, (void *xBuffer, void *yBuffer, Nd4jLong length), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
@ -272,7 +272,7 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, resultLen);
|
sd::Threads::parallel_for(func, 0, resultLen);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
|
@ -284,7 +284,7 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, resultLen);
|
sd::Threads::parallel_for(func, 0, resultLen);
|
||||||
}
|
}
|
||||||
result.tickWriteHost();
|
result.tickWriteHost();
|
||||||
return result;
|
return result;
|
||||||
|
@ -397,7 +397,7 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector<int
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, zLen);
|
sd::Threads::parallel_for(func, 0, zLen);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -26,7 +26,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
|
||||||
z[e] = func(f[e], s[e], t[e]);
|
z[e] = func(f[e], s[e], t[e]);
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(loop, 0, _length);
|
sd::Threads::parallel_for(loop, 0, _length);
|
||||||
} else {
|
} else {
|
||||||
if (f == z) {
|
if (f == z) {
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(loop, 0, _length);
|
sd::Threads::parallel_for(loop, 0, _length);
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
auto loop = PRAGMA_THREADS_FOR {
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
|
@ -54,7 +54,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(loop, 0, _length);
|
sd::Threads::parallel_for(loop, 0, _length);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -97,7 +97,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
|
||||||
z[e] = func(f[e], s[e]);
|
z[e] = func(f[e], s[e]);
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(loop, 0, _length);
|
sd::Threads::parallel_for(loop, 0, _length);
|
||||||
} else {
|
} else {
|
||||||
if (f == z) {
|
if (f == z) {
|
||||||
|
|
||||||
|
@ -110,7 +110,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(loop, 0, _length);
|
sd::Threads::parallel_for(loop, 0, _length);
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
auto loop = PRAGMA_THREADS_FOR {
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
|
@ -123,7 +123,7 @@ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(loop, 0, _length);
|
sd::Threads::parallel_for(loop, 0, _length);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -160,7 +160,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
|
||||||
z[e] = func(f[e]);
|
z[e] = func(f[e]);
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(loop, 0, _length);
|
sd::Threads::parallel_for(loop, 0, _length);
|
||||||
} else {
|
} else {
|
||||||
if (f == z) {
|
if (f == z) {
|
||||||
|
|
||||||
|
@ -172,7 +172,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(loop, 0, _length);
|
sd::Threads::parallel_for(loop, 0, _length);
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
auto loop = PRAGMA_THREADS_FOR {
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
|
@ -184,7 +184,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(loop, 0, _length);
|
sd::Threads::parallel_for(loop, 0, _length);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -221,7 +221,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
|
||||||
z[e] = func(e, f[e]);
|
z[e] = func(e, f[e]);
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(loop, 0, _length);
|
sd::Threads::parallel_for(loop, 0, _length);
|
||||||
} else {
|
} else {
|
||||||
if (f == z) {
|
if (f == z) {
|
||||||
|
|
||||||
|
@ -233,7 +233,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(loop, 0, _length);
|
sd::Threads::parallel_for(loop, 0, _length);
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
auto loop = PRAGMA_THREADS_FOR {
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
|
@ -245,7 +245,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(loop, 0, _length);
|
sd::Threads::parallel_for(loop, 0, _length);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -287,7 +287,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
|
||||||
z[e] = func((Nd4jLong) e, f[e], s[e]);
|
z[e] = func((Nd4jLong) e, f[e], s[e]);
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(loop, 0, _length);
|
sd::Threads::parallel_for(loop, 0, _length);
|
||||||
} else {
|
} else {
|
||||||
if (f == z) {
|
if (f == z) {
|
||||||
|
|
||||||
|
@ -300,7 +300,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(loop, 0, _length);
|
sd::Threads::parallel_for(loop, 0, _length);
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
auto loop = PRAGMA_THREADS_FOR {
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
|
@ -313,7 +313,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(N
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(loop, 0, _length);
|
sd::Threads::parallel_for(loop, 0, _length);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,7 +27,7 @@
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
#include <condition_variable>
|
#include <condition_variable>
|
||||||
|
|
||||||
namespace samediff {
|
namespace sd {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class BlockingQueue {
|
class BlockingQueue {
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -29,7 +29,7 @@
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <condition_variable>
|
#include <condition_variable>
|
||||||
|
|
||||||
namespace samediff {
|
namespace sd {
|
||||||
/**
|
/**
|
||||||
* This class is suited for passing functions to execution threads without queues
|
* This class is suited for passing functions to execution threads without queues
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -27,7 +27,7 @@
|
||||||
#include <condition_variable>
|
#include <condition_variable>
|
||||||
#include <system/op_boilerplate.h>
|
#include <system/op_boilerplate.h>
|
||||||
|
|
||||||
namespace samediff {
|
namespace sd {
|
||||||
class CallableWithArguments {
|
class CallableWithArguments {
|
||||||
FUNC_DO _function_do;
|
FUNC_DO _function_do;
|
||||||
FUNC_1D _function_1d;
|
FUNC_1D _function_1d;
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
#ifndef SD_ENGINE_H
|
#ifndef SD_ENGINE_H
|
||||||
#define SD_ENGINE_H
|
#define SD_ENGINE_H
|
||||||
|
|
||||||
namespace samediff {
|
namespace sd {
|
||||||
enum Engine {
|
enum Engine {
|
||||||
ENGINE_CPU = 0,
|
ENGINE_CPU = 0,
|
||||||
ENGINE_CUDA = 1,
|
ENGINE_CUDA = 1,
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
#ifndef SD_EXECUTIONMODE_H
|
#ifndef SD_EXECUTIONMODE_H
|
||||||
#define SD_EXECUTIONMODE_H
|
#define SD_EXECUTIONMODE_H
|
||||||
|
|
||||||
namespace samediff {
|
namespace sd {
|
||||||
enum ExecutionMode {
|
enum ExecutionMode {
|
||||||
MODE_UNDEFINED = 0,
|
MODE_UNDEFINED = 0,
|
||||||
MODE_TRAINING = 1,
|
MODE_TRAINING = 1,
|
||||||
|
|
|
@ -32,7 +32,7 @@
|
||||||
#include <execution/Ticket.h>
|
#include <execution/Ticket.h>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
|
|
||||||
namespace samediff {
|
namespace sd {
|
||||||
class ND4J_EXPORT ThreadPool {
|
class ND4J_EXPORT ThreadPool {
|
||||||
private:
|
private:
|
||||||
static ThreadPool* _INSTANCE;
|
static ThreadPool* _INSTANCE;
|
||||||
|
|
|
@ -14,9 +14,9 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @author raver119@gmail.com
|
||||||
//
|
//
|
||||||
#ifndef SAMEDIFF_THREADS_H
|
#ifndef SAMEDIFF_THREADS_H
|
||||||
#define SAMEDIFF_THREADS_H
|
#define SAMEDIFF_THREADS_H
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@
|
||||||
#include <system/Environment.h>
|
#include <system/Environment.h>
|
||||||
#include <system/op_enums.h>
|
#include <system/op_enums.h>
|
||||||
|
|
||||||
namespace samediff {
|
namespace sd {
|
||||||
class ND4J_EXPORT ThreadsHelper {
|
class ND4J_EXPORT ThreadsHelper {
|
||||||
public:
|
public:
|
||||||
static int numberOfThreads(int maxThreads, uint64_t numberOfElements);
|
static int numberOfThreads(int maxThreads, uint64_t numberOfElements);
|
||||||
|
@ -95,6 +95,14 @@ namespace samediff {
|
||||||
};
|
};
|
||||||
|
|
||||||
class ND4J_EXPORT Threads {
|
class ND4J_EXPORT Threads {
|
||||||
|
#ifdef _OPENMP
|
||||||
|
public:
|
||||||
|
static std::mutex gThreadmutex;
|
||||||
|
static uint64_t _nFreeThreads;
|
||||||
|
static bool tryAcquire(int numThreads);
|
||||||
|
static bool freeThreads(int numThreads);
|
||||||
|
#endif
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/**
|
/**
|
||||||
* This function executes 1 dimensional loop for a given number of threads
|
* This function executes 1 dimensional loop for a given number of threads
|
||||||
|
|
|
@ -28,7 +28,7 @@
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
|
|
||||||
namespace samediff {
|
namespace sd {
|
||||||
class ND4J_EXPORT Ticket {
|
class ND4J_EXPORT Ticket {
|
||||||
private:
|
private:
|
||||||
bool _acquired = false;
|
bool _acquired = false;
|
||||||
|
|
|
@ -22,7 +22,7 @@
|
||||||
#include <execution/CallableWithArguments.h>
|
#include <execution/CallableWithArguments.h>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
|
|
||||||
namespace samediff {
|
namespace sd {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
BlockingQueue<T>::BlockingQueue(int queueSize) {
|
BlockingQueue<T>::BlockingQueue(int queueSize) {
|
||||||
_size = 0;
|
_size = 0;
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
#include <execution/CallableInterface.h>
|
#include <execution/CallableInterface.h>
|
||||||
#include <helpers/logger.h>
|
#include <helpers/logger.h>
|
||||||
|
|
||||||
namespace samediff {
|
namespace sd {
|
||||||
CallableInterface::CallableInterface() {
|
CallableInterface::CallableInterface() {
|
||||||
// initial state is available
|
// initial state is available
|
||||||
_available = true;
|
_available = true;
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
|
|
||||||
#include <execution/CallableWithArguments.h>
|
#include <execution/CallableWithArguments.h>
|
||||||
|
|
||||||
namespace samediff {
|
namespace sd {
|
||||||
CallableWithArguments::CallableWithArguments(FUNC_DO func, uint64_t thread_id, uint64_t numThreads) {
|
CallableWithArguments::CallableWithArguments(FUNC_DO func, uint64_t thread_id, uint64_t numThreads) {
|
||||||
_function_do = func;
|
_function_do = func;
|
||||||
_finished = false;
|
_finished = false;
|
||||||
|
|
|
@ -26,7 +26,7 @@
|
||||||
//#include <windows.h>
|
//#include <windows.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace samediff {
|
namespace sd {
|
||||||
|
|
||||||
// this function executed once per thread, it polls functions from queue, and executes them via wrapper
|
// this function executed once per thread, it polls functions from queue, and executes them via wrapper
|
||||||
static void executionLoop_(int thread_id, BlockingQueue<CallableWithArguments*> *queue) {
|
static void executionLoop_(int thread_id, BlockingQueue<CallableWithArguments*> *queue) {
|
||||||
|
@ -183,7 +183,7 @@ namespace samediff {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ThreadPool::release(samediff::Ticket *ticket) {
|
void ThreadPool::release(sd::Ticket *ticket) {
|
||||||
// returning ticket back to the queue
|
// returning ticket back to the queue
|
||||||
std::unique_lock<std::mutex> lock(_lock);
|
std::unique_lock<std::mutex> lock(_lock);
|
||||||
_tickets.push(ticket);
|
_tickets.push(ticket);
|
||||||
|
|
|
@ -25,8 +25,14 @@
|
||||||
#include <math/templatemath.h>
|
#include <math/templatemath.h>
|
||||||
#include <helpers/shape.h>
|
#include <helpers/shape.h>
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
|
||||||
namespace samediff {
|
#include <omp.h>
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
|
||||||
int ThreadsHelper::numberOfThreads(int maxThreads, uint64_t numberOfElements) {
|
int ThreadsHelper::numberOfThreads(int maxThreads, uint64_t numberOfElements) {
|
||||||
// let's see how many threads we actually need first
|
// let's see how many threads we actually need first
|
||||||
|
@ -51,34 +57,34 @@ namespace samediff {
|
||||||
Span3 Span3::build(int loop, uint64_t threadID, uint64_t numThreads, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, int64_t incZ) {
|
Span3 Span3::build(int loop, uint64_t threadID, uint64_t numThreads, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, int64_t incZ) {
|
||||||
switch (loop) {
|
switch (loop) {
|
||||||
case 1: {
|
case 1: {
|
||||||
auto span = (stopX - startX) / numThreads;
|
auto span = (stopX - startX) / numThreads;
|
||||||
auto s = span * threadID;
|
auto s = span * threadID;
|
||||||
auto e = s + span;
|
auto e = s + span;
|
||||||
if (threadID == numThreads - 1)
|
if (threadID == numThreads - 1)
|
||||||
e = stopX;
|
e = stopX;
|
||||||
|
|
||||||
return Span3(s, e, incX, startY, stopY, incY, startZ, stopZ, incZ);
|
return Span3(s, e, incX, startY, stopY, incY, startZ, stopZ, incZ);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case 2: {
|
case 2: {
|
||||||
auto span = (stopY - startY) / numThreads;
|
auto span = (stopY - startY) / numThreads;
|
||||||
auto s = span * threadID;
|
auto s = span * threadID;
|
||||||
auto e = s + span;
|
auto e = s + span;
|
||||||
if (threadID == numThreads - 1)
|
if (threadID == numThreads - 1)
|
||||||
e = stopY;
|
e = stopY;
|
||||||
|
|
||||||
return Span3(startX, stopX, incX, s, e, incY, startZ, stopZ, incZ);
|
return Span3(startX, stopX, incX, s, e, incY, startZ, stopZ, incZ);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case 3: {
|
case 3: {
|
||||||
auto span = (stopZ - startZ) / numThreads;
|
auto span = (stopZ - startZ) / numThreads;
|
||||||
auto s = span * threadID;
|
auto s = span * threadID;
|
||||||
auto e = s + span;
|
auto e = s + span;
|
||||||
if (threadID == numThreads - 1)
|
if (threadID == numThreads - 1)
|
||||||
e = stopZ;
|
e = stopZ;
|
||||||
|
|
||||||
return Span3(startX, stopX, incX, startY, stopY, incY, s, e, incZ);
|
return Span3(startX, stopX, incX, startY, stopY, incY, s, e, incZ);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("");
|
throw std::runtime_error("");
|
||||||
|
@ -116,24 +122,24 @@ namespace samediff {
|
||||||
|
|
||||||
switch (loop) {
|
switch (loop) {
|
||||||
case 1: {
|
case 1: {
|
||||||
auto span = (stopX - startX) / numThreads;
|
auto span = (stopX - startX) / numThreads;
|
||||||
auto s = span * threadID;
|
auto s = span * threadID;
|
||||||
auto e = s + span;
|
auto e = s + span;
|
||||||
if (threadID == numThreads - 1)
|
if (threadID == numThreads - 1)
|
||||||
e = stopX;
|
e = stopX;
|
||||||
|
|
||||||
return Span2(s, e, incX, startY, stopY, incY);
|
return Span2(s, e, incX, startY, stopY, incY);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case 2: {
|
case 2: {
|
||||||
auto span = (stopY - startY) / numThreads;
|
auto span = (stopY - startY) / numThreads;
|
||||||
auto s = span * threadID;
|
auto s = span * threadID;
|
||||||
auto e = s + span;
|
auto e = s + span;
|
||||||
if (threadID == numThreads - 1)
|
if (threadID == numThreads - 1)
|
||||||
e = stopY;
|
e = stopY;
|
||||||
|
|
||||||
return Span2(startX, stopX, incX, s, e, incY);
|
return Span2(startX, stopX, incX, s, e, incY);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("");
|
throw std::runtime_error("");
|
||||||
|
@ -270,7 +276,7 @@ namespace samediff {
|
||||||
auto remY = iters_y % maxThreads;
|
auto remY = iters_y % maxThreads;
|
||||||
|
|
||||||
// in some cases there's nothing to think about, part 2
|
// in some cases there's nothing to think about, part 2
|
||||||
if ((iters_x >= maxThreads && remX == 0 )|| (iters_y >= maxThreads && remY == 0))
|
if ((iters_x >= maxThreads && remX == 0) || (iters_y >= maxThreads && remY == 0))
|
||||||
return maxThreads;
|
return maxThreads;
|
||||||
|
|
||||||
// at this point we suppose that there's no loop perfectly matches number of our threads
|
// at this point we suppose that there's no loop perfectly matches number of our threads
|
||||||
|
@ -339,11 +345,35 @@ namespace samediff {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
|
||||||
|
std::mutex Threads::gThreadmutex;
|
||||||
|
uint64_t Threads::_nFreeThreads = sd::Environment::getInstance()->maxThreads();
|
||||||
|
|
||||||
|
bool Threads::tryAcquire(int numThreads){
|
||||||
|
std::lock_guard<std::mutex> lock( gThreadmutex );
|
||||||
|
auto nThreads = _nFreeThreads - numThreads;
|
||||||
|
if(nThreads >= 1){
|
||||||
|
_nFreeThreads = nThreads;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Threads::freeThreads(int numThreads){
|
||||||
|
std::lock_guard<std::mutex> lock( gThreadmutex );
|
||||||
|
_nFreeThreads += numThreads;
|
||||||
|
// check if correct number of threads
|
||||||
|
return _nFreeThreads > sd::Environment::getInstance()->maxThreads();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
int Threads::parallel_tad(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, uint32_t numThreads) {
|
int Threads::parallel_tad(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, uint32_t numThreads) {
|
||||||
if (start > stop)
|
if (start > stop)
|
||||||
throw std::runtime_error("Threads::parallel_for got start > stop");
|
throw std::runtime_error("Threads::parallel_for got start > stop");
|
||||||
|
|
||||||
auto delta = (stop - start);
|
auto delta = (stop - start) / increment;
|
||||||
|
|
||||||
if (numThreads > delta)
|
if (numThreads > delta)
|
||||||
numThreads = delta;
|
numThreads = delta;
|
||||||
|
@ -357,35 +387,57 @@ namespace samediff {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
|
#ifdef _OPENMP
|
||||||
if (ticket != nullptr) {
|
|
||||||
// if we got our threads - we'll run our jobs here
|
|
||||||
auto span = delta / numThreads;
|
|
||||||
|
|
||||||
for (uint32_t e = 0; e < numThreads; e++) {
|
if (tryAcquire(numThreads)) {
|
||||||
auto start_ = span * e + start;
|
#pragma omp parallel for num_threads(numThreads)
|
||||||
auto stop_ = start_ + span;
|
for (auto e = start; e < stop; e += increment) {
|
||||||
|
function(omp_get_thread_num(), e, e + 1, 1);
|
||||||
// last thread will process tail
|
|
||||||
if (e == numThreads - 1)
|
|
||||||
stop_ = stop;
|
|
||||||
|
|
||||||
// putting the task into the queue for a given thread
|
|
||||||
ticket->enqueue(e, numThreads, function, start_, stop_, increment);
|
|
||||||
}
|
}
|
||||||
|
freeThreads(numThreads);
|
||||||
// block and wait till all threads finished the job
|
|
||||||
ticket->waitAndRelease();
|
|
||||||
|
|
||||||
// we tell that parallelism request succeeded
|
|
||||||
return numThreads;
|
return numThreads;
|
||||||
} else {
|
}
|
||||||
|
else {
|
||||||
// if there were no threads available - we'll execute function right within current thread
|
// if there were no threads available - we'll execute function right within current thread
|
||||||
function(0, start, stop, increment);
|
function(0, start, stop, increment);
|
||||||
|
|
||||||
// we tell that parallelism request declined
|
// we tell that parallelism request declined
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
|
||||||
|
sd::Environment::getInstance()->maxThreads();
|
||||||
|
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
|
||||||
|
if (ticket != nullptr) {
|
||||||
|
// if we got our threads - we'll run our jobs here
|
||||||
|
auto span = delta / numThreads;
|
||||||
|
|
||||||
|
for (uint32_t e = 0; e < numThreads; e++) {
|
||||||
|
auto start_ = span * e + start;
|
||||||
|
auto stop_ = start_ + span;
|
||||||
|
|
||||||
|
// last thread will process tail
|
||||||
|
if (e == numThreads - 1)
|
||||||
|
stop_ = stop;
|
||||||
|
|
||||||
|
// putting the task into the queue for a given thread
|
||||||
|
ticket->enqueue(e, numThreads, function, start_, stop_, increment);
|
||||||
|
}
|
||||||
|
|
||||||
|
// block and wait till all threads finished the job
|
||||||
|
ticket->waitAndRelease();
|
||||||
|
|
||||||
|
// we tell that parallelism request succeeded
|
||||||
|
return numThreads;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// if there were no threads available - we'll execute function right within current thread
|
||||||
|
function(0, start, stop, increment);
|
||||||
|
|
||||||
|
// we tell that parallelism request declined
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
int Threads::parallel_for(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, uint32_t numThreads) {
|
int Threads::parallel_for(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, uint32_t numThreads) {
|
||||||
|
@ -448,28 +500,53 @@ namespace samediff {
|
||||||
|
|
||||||
// but we still mimic multithreaded execution
|
// but we still mimic multithreaded execution
|
||||||
return numThreads;
|
return numThreads;
|
||||||
} else {
|
}
|
||||||
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
|
else {
|
||||||
if (ticket != nullptr) {
|
#ifdef _OPENMP
|
||||||
|
|
||||||
for (int e = 0; e < numThreads; e++) {
|
if (tryAcquire(numThreads)) {
|
||||||
auto threadId = numThreads - e - 1;
|
#pragma omp parallel for num_threads(numThreads) collapse(2)
|
||||||
auto span = Span2::build(splitLoop, threadId, numThreads, startX, stopX, incX, startY, stopY, incY);
|
for (auto x = startX; x < stopX; x += incX) {
|
||||||
|
for (auto y = startY; y < stopY; y += incY) {
|
||||||
ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY());
|
function(omp_get_thread_num(), x, x+1, 1, y, y+1, 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
freeThreads(numThreads);
|
||||||
// block until all threads finish their job
|
|
||||||
ticket->waitAndRelease();
|
|
||||||
|
|
||||||
return numThreads;
|
return numThreads;
|
||||||
} else {
|
}
|
||||||
|
else {
|
||||||
// if there were no threads available - we'll execute function right within current thread
|
// if there were no threads available - we'll execute function right within current thread
|
||||||
function(0, startX, stopX, incX, startY, stopY, incY);
|
function(0, startX, stopX, incX, startY, stopY, incY);
|
||||||
|
|
||||||
// we tell that parallelism request declined
|
// we tell that parallelism request declined
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
|
||||||
|
if (ticket != nullptr) {
|
||||||
|
|
||||||
|
for (int e = 0; e < numThreads; e++) {
|
||||||
|
auto threadId = numThreads - e - 1;
|
||||||
|
auto span = Span2::build(splitLoop, threadId, numThreads, startX, stopX, incX, startY, stopY, incY);
|
||||||
|
|
||||||
|
ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY());
|
||||||
|
}
|
||||||
|
|
||||||
|
// block until all threads finish their job
|
||||||
|
ticket->waitAndRelease();
|
||||||
|
|
||||||
|
return numThreads;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// if there were no threads available - we'll execute function right within current thread
|
||||||
|
function(0, startX, stopX, incX, startY, stopY, incY);
|
||||||
|
|
||||||
|
// we tell that parallelism request declined
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -484,6 +561,35 @@ namespace samediff {
|
||||||
if (startZ > stopZ)
|
if (startZ > stopZ)
|
||||||
throw std::runtime_error("Threads::parallel_for got startZ > stopZ");
|
throw std::runtime_error("Threads::parallel_for got startZ > stopZ");
|
||||||
|
|
||||||
|
if (numThreads == 1) {
|
||||||
|
// loop is too small - executing function as is
|
||||||
|
function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
|
||||||
|
if (tryAcquire(numThreads)) {
|
||||||
|
#pragma omp parallel for num_threads(numThreads) collapse(3)
|
||||||
|
for (auto x = startX; x < stopX; x += incX) {
|
||||||
|
for (auto y = startY; y < stopY; y += incY) {
|
||||||
|
for (auto z = startZ; z < stopZ; z += incZ) {
|
||||||
|
function(omp_get_thread_num(), x, x+1, 1, y, y+1, 1, z, z+1, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
freeThreads(numThreads);
|
||||||
|
return numThreads;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// if there were no threads available - we'll execute function right within current thread
|
||||||
|
function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
|
||||||
|
|
||||||
|
// we tell that parallelism request declined
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
#else
|
||||||
auto delta_x = stopX - startX;
|
auto delta_x = stopX - startX;
|
||||||
auto delta_y = stopY - startY;
|
auto delta_y = stopY - startY;
|
||||||
auto delta_z = stopZ - startZ;
|
auto delta_z = stopZ - startZ;
|
||||||
|
@ -500,52 +606,79 @@ namespace samediff {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
|
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
|
||||||
if (ticket != nullptr) {
|
if (ticket != nullptr) {
|
||||||
auto splitLoop = ThreadsHelper::pickLoop3d(numThreads, itersX, itersY, itersZ);
|
auto splitLoop = ThreadsHelper::pickLoop3d(numThreads, itersX, itersY, itersZ);
|
||||||
|
|
||||||
for (int e = 0; e < numThreads; e++) {
|
for (int e = 0; e < numThreads; e++) {
|
||||||
auto thread_id = numThreads - e - 1;
|
auto thread_id = numThreads - e - 1;
|
||||||
auto span = Span3::build(splitLoop, thread_id, numThreads, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
|
auto span = Span3::build(splitLoop, thread_id, numThreads, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
|
||||||
|
|
||||||
ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY(), span.startZ(), span.stopZ(), span.incZ());
|
ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY(), span.startZ(), span.stopZ(), span.incZ());
|
||||||
}
|
}
|
||||||
|
|
||||||
// block until we're done
|
// block until we're done
|
||||||
ticket->waitAndRelease();
|
ticket->waitAndRelease();
|
||||||
|
|
||||||
// we tell that parallelism request succeeded
|
// we tell that parallelism request succeeded
|
||||||
return numThreads;
|
return numThreads;
|
||||||
} else {
|
}
|
||||||
// if there were no threads available - we'll execute function right within current thread
|
else {
|
||||||
function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
|
// if there were no threads available - we'll execute function right within current thread
|
||||||
|
function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
|
||||||
// we tell that parallelism request declined
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// we tell that parallelism request declined
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
int Threads::parallel_do(FUNC_DO function, uint64_t numThreads) {
|
int Threads::parallel_do(FUNC_DO function, uint64_t numThreads) {
|
||||||
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
|
|
||||||
if (ticket != nullptr) {
|
|
||||||
|
|
||||||
// submit tasks one by one
|
if (numThreads == 1) {
|
||||||
for (uint64_t e = 0; e < numThreads - 1; e++)
|
function(0, numThreads);
|
||||||
ticket->enqueue(e, numThreads, function);
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
function(numThreads - 1, numThreads);
|
#ifdef _OPENMP
|
||||||
|
|
||||||
ticket->waitAndRelease();
|
if (tryAcquire(numThreads)) {
|
||||||
|
#pragma omp parallel for num_threads(numThreads)
|
||||||
|
for (int e = 0; e < numThreads; e++) {
|
||||||
|
function(e, numThreads);
|
||||||
|
}
|
||||||
|
|
||||||
|
freeThreads(numThreads);
|
||||||
return numThreads;
|
return numThreads;
|
||||||
} else {
|
}
|
||||||
|
else {
|
||||||
// if there's no threads available - we'll execute function sequentially one by one
|
// if there's no threads available - we'll execute function sequentially one by one
|
||||||
for (uint64_t e = 0; e < numThreads; e++)
|
for (uint64_t e = 0; e < numThreads; e++)
|
||||||
function(e, numThreads);
|
function(e, numThreads);
|
||||||
|
|
||||||
return numThreads;
|
return numThreads;
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
|
||||||
|
if (ticket != nullptr) {
|
||||||
|
|
||||||
|
// submit tasks one by one
|
||||||
|
for (uint64_t e = 0; e < numThreads - 1; e++)
|
||||||
|
ticket->enqueue(e, numThreads, function);
|
||||||
|
|
||||||
|
function(numThreads - 1, numThreads);
|
||||||
|
|
||||||
|
ticket->waitAndRelease();
|
||||||
|
|
||||||
|
return numThreads;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// if there's no threads available - we'll execute function sequentially one by one
|
||||||
|
for (uint64_t e = 0; e < numThreads; e++)
|
||||||
|
function(e, numThreads);
|
||||||
|
|
||||||
|
return numThreads;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
return numThreads;
|
return numThreads;
|
||||||
}
|
}
|
||||||
|
@ -565,26 +698,44 @@ namespace samediff {
|
||||||
if (numThreads == 1)
|
if (numThreads == 1)
|
||||||
return function(0, start, stop, increment);
|
return function(0, start, stop, increment);
|
||||||
|
|
||||||
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
|
|
||||||
if (ticket == nullptr)
|
|
||||||
return function(0, start, stop, increment);
|
|
||||||
|
|
||||||
// create temporary array
|
// create temporary array
|
||||||
int64_t intermediatery[256];
|
int64_t intermediatery[256];
|
||||||
auto span = delta / numThreads;
|
auto span = delta / numThreads;
|
||||||
|
|
||||||
// execute threads in parallel
|
#ifdef _OPENMP
|
||||||
for (uint32_t e = 0; e < numThreads; e++) {
|
if (tryAcquire(numThreads)) {
|
||||||
auto start_ = span * e + start;
|
#pragma omp parallel for num_threads(numThreads)
|
||||||
auto stop_ = span * (e + 1) + start;
|
for (int e = 0; e < numThreads; e++) {
|
||||||
|
auto start_ = span * e + start;
|
||||||
|
auto stop_ = span * (e + 1) + start;
|
||||||
|
|
||||||
if (e == numThreads - 1)
|
intermediatery[e] = function(e, start_, e == numThreads - 1 ? stop : stop_, increment);
|
||||||
intermediatery[e] = function(e, start_, stop, increment);
|
}
|
||||||
else
|
freeThreads(numThreads);
|
||||||
ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, stop_, increment);
|
|
||||||
}
|
}
|
||||||
|
else{
|
||||||
|
// if there were no threads available - we'll execute function right within current thread
|
||||||
|
return function(0, start, stop, increment);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
|
||||||
|
if (ticket == nullptr)
|
||||||
|
return function(0, start, stop, increment);
|
||||||
|
|
||||||
ticket->waitAndRelease();
|
// execute threads in parallel
|
||||||
|
for (uint32_t e = 0; e < numThreads; e++) {
|
||||||
|
auto start_ = span * e + start;
|
||||||
|
auto stop_ = span * (e + 1) + start;
|
||||||
|
|
||||||
|
if (e == numThreads - 1)
|
||||||
|
intermediatery[e] = function(e, start_, stop, increment);
|
||||||
|
else
|
||||||
|
ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, stop_, increment);
|
||||||
|
}
|
||||||
|
|
||||||
|
ticket->waitAndRelease();
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
// aggregate results in single thread
|
// aggregate results in single thread
|
||||||
for (uint64_t e = 1; e < numThreads; e++)
|
for (uint64_t e = 1; e < numThreads; e++)
|
||||||
|
@ -609,26 +760,47 @@ namespace samediff {
|
||||||
if (numThreads == 1)
|
if (numThreads == 1)
|
||||||
return function(0, start, stop, increment);
|
return function(0, start, stop, increment);
|
||||||
|
|
||||||
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
|
|
||||||
if (ticket == nullptr)
|
|
||||||
return function(0, start, stop, increment);
|
|
||||||
|
|
||||||
// create temporary array
|
// create temporary array
|
||||||
double intermediatery[256];
|
double intermediatery[256];
|
||||||
auto span = delta / numThreads;
|
auto span = delta / numThreads;
|
||||||
|
|
||||||
// execute threads in parallel
|
#ifdef _OPENMP
|
||||||
for (uint32_t e = 0; e < numThreads; e++) {
|
|
||||||
auto start_ = span * e + start;
|
|
||||||
auto stop_ = span * (e + 1) + start;
|
|
||||||
|
|
||||||
if (e == numThreads - 1)
|
if (tryAcquire(numThreads)) {
|
||||||
intermediatery[e] = function(e, start_, stop, increment);
|
#pragma omp parallel for num_threads(numThreads)
|
||||||
else
|
for (int e = 0; e < numThreads; e++) {
|
||||||
ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, stop_, increment);
|
auto start_ = span * e + start;
|
||||||
|
auto stop_ = span * (e + 1) + start;
|
||||||
|
|
||||||
|
intermediatery[e] = function(e, start_, e == numThreads - 1 ? stop : stop_, increment);
|
||||||
|
}
|
||||||
|
freeThreads(numThreads);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
// if there were no threads available - we'll execute function right within current thread
|
||||||
|
return function(0, start, stop, increment);
|
||||||
}
|
}
|
||||||
|
|
||||||
ticket->waitAndRelease();
|
#else
|
||||||
|
|
||||||
|
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
|
||||||
|
if (ticket == nullptr)
|
||||||
|
return function(0, start, stop, increment);
|
||||||
|
|
||||||
|
// execute threads in parallel
|
||||||
|
for (uint32_t e = 0; e < numThreads; e++) {
|
||||||
|
auto start_ = span * e + start;
|
||||||
|
auto stop_ = span * (e + 1) + start;
|
||||||
|
|
||||||
|
if (e == numThreads - 1)
|
||||||
|
intermediatery[e] = function(e, start_, stop, increment);
|
||||||
|
else
|
||||||
|
ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, stop_, increment);
|
||||||
|
}
|
||||||
|
|
||||||
|
ticket->waitAndRelease();
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
// aggregate results in single thread
|
// aggregate results in single thread
|
||||||
for (uint64_t e = 1; e < numThreads; e++)
|
for (uint64_t e = 1; e < numThreads; e++)
|
||||||
|
@ -639,7 +811,7 @@ namespace samediff {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
int Threads::parallel_aligned_increment(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, size_t type_size , uint32_t req_numThreads) {
|
int Threads::parallel_aligned_increment(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, size_t type_size, uint32_t req_numThreads) {
|
||||||
if (start > stop)
|
if (start > stop)
|
||||||
throw std::runtime_error("Threads::parallel_for got start > stop");
|
throw std::runtime_error("Threads::parallel_for got start > stop");
|
||||||
auto num_elements = (stop - start);
|
auto num_elements = (stop - start);
|
||||||
|
@ -647,6 +819,7 @@ namespace samediff {
|
||||||
//so we will partition considering delta but not total elements
|
//so we will partition considering delta but not total elements
|
||||||
auto delta = (stop - start) / increment;
|
auto delta = (stop - start) / increment;
|
||||||
|
|
||||||
|
|
||||||
// in some cases we just fire func as is
|
// in some cases we just fire func as is
|
||||||
if (delta == 0 || req_numThreads == 1) {
|
if (delta == 0 || req_numThreads == 1) {
|
||||||
function(0, start, stop, increment);
|
function(0, start, stop, increment);
|
||||||
|
@ -654,7 +827,24 @@ namespace samediff {
|
||||||
}
|
}
|
||||||
int numThreads = 0;
|
int numThreads = 0;
|
||||||
|
|
||||||
int adjusted_numThreads = samediff::ThreadsHelper::numberOfThreads(req_numThreads, (num_elements * sizeof(double)) / (200 * type_size));
|
struct th_span {
|
||||||
|
Nd4jLong start;
|
||||||
|
Nd4jLong end;
|
||||||
|
};
|
||||||
|
#ifdef _OPENMP
|
||||||
|
constexpr int max_thread_count = 8;
|
||||||
|
#else
|
||||||
|
constexpr int max_thread_count = 1024;
|
||||||
|
#endif
|
||||||
|
th_span thread_spans[max_thread_count];
|
||||||
|
|
||||||
|
req_numThreads = req_numThreads > max_thread_count ? max_thread_count : req_numThreads;
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
int adjusted_numThreads = max_thread_count;
|
||||||
|
#else
|
||||||
|
int adjusted_numThreads = sd::ThreadsHelper::numberOfThreads(req_numThreads, (num_elements * sizeof(double)) / (200 * type_size));
|
||||||
|
#endif
|
||||||
|
|
||||||
if (adjusted_numThreads > delta)
|
if (adjusted_numThreads > delta)
|
||||||
adjusted_numThreads = delta;
|
adjusted_numThreads = delta;
|
||||||
|
@ -663,61 +853,89 @@ namespace samediff {
|
||||||
function(0, start, stop, increment);
|
function(0, start, stop, increment);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
//take span as ceil
|
//take span as ceil
|
||||||
auto spand = std::ceil((double)delta / (double)adjusted_numThreads);
|
auto spand = std::ceil((double)delta / (double)adjusted_numThreads);
|
||||||
numThreads = static_cast<int>(std::ceil((double)delta / spand));
|
numThreads = static_cast<int>(std::ceil((double)delta / spand));
|
||||||
auto span = static_cast<Nd4jLong>(spand);
|
auto span = static_cast<Nd4jLong>(spand);
|
||||||
|
|
||||||
auto ticket = samediff::ThreadPool::getInstance()->tryAcquire(numThreads);
|
|
||||||
if (ticket != nullptr) {
|
|
||||||
//tail_add is additional value of the last part
|
|
||||||
//it could be negative or positive
|
|
||||||
//we will spread that value across
|
|
||||||
auto tail_add = delta - numThreads * span;
|
|
||||||
Nd4jLong begin = 0;
|
|
||||||
Nd4jLong end = 0;
|
|
||||||
|
|
||||||
//we will try enqueue bigger parts first
|
//tail_add is additional value of the last part
|
||||||
decltype(span) span1, span2;
|
//it could be negative or positive
|
||||||
int last = 0;
|
//we will spread that value across
|
||||||
if (tail_add >= 0) {
|
auto tail_add = delta - numThreads * span;
|
||||||
//for span == 1 , tail_add is 0
|
Nd4jLong begin = 0;
|
||||||
last = tail_add;
|
Nd4jLong end = 0;
|
||||||
span1 = span + 1;
|
|
||||||
span2 = span;
|
//we will try enqueue bigger parts first
|
||||||
|
decltype(span) span1, span2;
|
||||||
|
int last = 0;
|
||||||
|
if (tail_add >= 0) {
|
||||||
|
//for span == 1 , tail_add is 0
|
||||||
|
last = tail_add;
|
||||||
|
span1 = span + 1;
|
||||||
|
span2 = span;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
last = numThreads + tail_add;// -std::abs(tail_add);
|
||||||
|
span1 = span;
|
||||||
|
span2 = span - 1;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < last; i++) {
|
||||||
|
end = begin + span1 * increment;
|
||||||
|
// putting the task into the queue for a given thread
|
||||||
|
thread_spans[i].start = begin;
|
||||||
|
thread_spans[i].end = end;
|
||||||
|
begin = end;
|
||||||
|
}
|
||||||
|
for (int i = last; i < numThreads - 1; i++) {
|
||||||
|
end = begin + span2 * increment;
|
||||||
|
// putting the task into the queue for a given thread
|
||||||
|
thread_spans[i].start = begin;
|
||||||
|
thread_spans[i].end = end;
|
||||||
|
begin = end;
|
||||||
|
}
|
||||||
|
//for last one enqueue last offset as stop
|
||||||
|
//we need it in case our ((stop-start) % increment ) > 0
|
||||||
|
thread_spans[numThreads - 1].start = begin;
|
||||||
|
thread_spans[numThreads - 1].end = stop;
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
if (tryAcquire(numThreads)) {
|
||||||
|
#pragma omp parallel for num_threads(numThreads)
|
||||||
|
for (size_t j = 0; j < numThreads; j++) {
|
||||||
|
function(j, thread_spans[j].start, thread_spans[j].end, increment);
|
||||||
}
|
}
|
||||||
else {
|
freeThreads(numThreads);
|
||||||
last = numThreads + tail_add;// -std::abs(tail_add);
|
|
||||||
span1 = span;
|
|
||||||
span2 = span - 1;
|
|
||||||
}
|
|
||||||
for (int i = 0; i < last; i++) {
|
|
||||||
end = begin + span1 * increment;
|
|
||||||
// putting the task into the queue for a given thread
|
|
||||||
ticket->enqueue(i, numThreads, function, begin, end, increment);
|
|
||||||
begin = end;
|
|
||||||
}
|
|
||||||
for (int i = last; i < numThreads - 1; i++) {
|
|
||||||
end = begin + span2 * increment;
|
|
||||||
// putting the task into the queue for a given thread
|
|
||||||
ticket->enqueue(i, numThreads, function, begin, end, increment);
|
|
||||||
begin = end;
|
|
||||||
}
|
|
||||||
//for last one enqueue last offset as stop
|
|
||||||
//we need it in case our ((stop-start) % increment ) > 0
|
|
||||||
ticket->enqueue(numThreads - 1, numThreads, function, begin, stop, increment);
|
|
||||||
// block and wait till all threads finished the job
|
|
||||||
ticket->waitAndRelease();
|
|
||||||
// we tell that parallelism request succeeded
|
|
||||||
return numThreads;
|
return numThreads;
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
// if there were no threads available - we'll execute function right within current thread
|
|
||||||
function(0, start, stop, increment);
|
function(0, start, stop, increment);
|
||||||
// we tell that parallelism request declined
|
// we tell that parallelism request declined
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
auto ticket = sd::ThreadPool::getInstance()->tryAcquire(numThreads);
|
||||||
|
if (ticket != nullptr) {
|
||||||
|
|
||||||
|
for (size_t j = 0; j < numThreads; j++) {
|
||||||
|
ticket->enqueue(j, numThreads, function, thread_spans[j].start, thread_spans[j].end, increment);
|
||||||
|
}
|
||||||
|
// block and wait till all threads finished the job
|
||||||
|
ticket->waitAndRelease();
|
||||||
|
// we tell that parallelism request succeeded
|
||||||
|
return numThreads;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// if there were no threads available - we'll execute function right within current thread
|
||||||
|
function(0, start, stop, increment);
|
||||||
|
// we tell that parallelism request declined
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,7 @@
|
||||||
#include <helpers/logger.h>
|
#include <helpers/logger.h>
|
||||||
#include <array>
|
#include <array>
|
||||||
|
|
||||||
namespace samediff {
|
namespace sd {
|
||||||
Ticket::Ticket(const std::vector<BlockingQueue<CallableWithArguments*>*> &queues) {
|
Ticket::Ticket(const std::vector<BlockingQueue<CallableWithArguments*>*> &queues) {
|
||||||
_acquired = true;
|
_acquired = true;
|
||||||
_queues = queues;
|
_queues = queues;
|
||||||
|
@ -38,7 +38,7 @@ namespace samediff {
|
||||||
return _acquired;
|
return _acquired;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ticket::enqueue(int thread_id, samediff::CallableWithArguments *callable) {
|
void Ticket::enqueue(int thread_id, sd::CallableWithArguments *callable) {
|
||||||
_queues[thread_id]->put(callable);
|
_queues[thread_id]->put(callable);
|
||||||
_callables.emplace_back(callable);
|
_callables.emplace_back(callable);
|
||||||
}
|
}
|
||||||
|
@ -88,7 +88,7 @@ namespace samediff {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void Ticket::attach(uint32_t thread_id, samediff::CallableInterface *interface) {
|
void Ticket::attach(uint32_t thread_id, sd::CallableInterface *interface) {
|
||||||
_interfaces[thread_id] = interface;
|
_interfaces[thread_id] = interface;
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -112,7 +112,7 @@ namespace sd {
|
||||||
sd::random::RandomBuffer* getRNG();
|
sd::random::RandomBuffer* getRNG();
|
||||||
void setRNG(sd::random::RandomBuffer* rng);
|
void setRNG(sd::random::RandomBuffer* rng);
|
||||||
|
|
||||||
void setTargetEngine(samediff::Engine engine);
|
void setTargetEngine(sd::Engine engine);
|
||||||
|
|
||||||
VariableSpace *getVariableSpace();
|
VariableSpace *getVariableSpace();
|
||||||
|
|
||||||
|
@ -228,8 +228,8 @@ namespace sd {
|
||||||
void setShapeFunctionOverride(bool reallyOverride);
|
void setShapeFunctionOverride(bool reallyOverride);
|
||||||
bool shapeFunctionOverride();
|
bool shapeFunctionOverride();
|
||||||
|
|
||||||
samediff::ExecutionMode executionMode();
|
sd::ExecutionMode executionMode();
|
||||||
void setExecutionMode(samediff::ExecutionMode executionMode);
|
void setExecutionMode(sd::ExecutionMode executionMode);
|
||||||
|
|
||||||
bool isTraining();
|
bool isTraining();
|
||||||
bool isInference();
|
bool isInference();
|
||||||
|
|
|
@ -64,9 +64,9 @@ namespace sd {
|
||||||
bool _useMKLDNN = sd::Environment::getInstance()->isUseMKLDNN();
|
bool _useMKLDNN = sd::Environment::getInstance()->isUseMKLDNN();
|
||||||
|
|
||||||
// target engine for execution
|
// target engine for execution
|
||||||
samediff::Engine _engine = DEFAULT_ENGINE;
|
sd::Engine _engine = DEFAULT_ENGINE;
|
||||||
|
|
||||||
samediff::ExecutionMode _execMode = samediff::ExecutionMode::MODE_UNDEFINED;
|
sd::ExecutionMode _execMode = sd::ExecutionMode::MODE_UNDEFINED;
|
||||||
public:
|
public:
|
||||||
explicit ContextPrototype(sd::ops::OpDescriptor* opDescriptor = nullptr, int nodeId = 1, bool inPlace = false);
|
explicit ContextPrototype(sd::ops::OpDescriptor* opDescriptor = nullptr, int nodeId = 1, bool inPlace = false);
|
||||||
~ContextPrototype() = default;
|
~ContextPrototype() = default;
|
||||||
|
@ -99,7 +99,7 @@ namespace sd {
|
||||||
std::vector<sd::DataType>* getDArguments();
|
std::vector<sd::DataType>* getDArguments();
|
||||||
std::vector<int>* getAxis();
|
std::vector<int>* getAxis();
|
||||||
|
|
||||||
samediff::Engine engine();
|
sd::Engine engine();
|
||||||
|
|
||||||
size_t numT();
|
size_t numT();
|
||||||
size_t numI();
|
size_t numI();
|
||||||
|
|
|
@ -107,7 +107,7 @@ namespace sd {
|
||||||
delete _context;
|
delete _context;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Context::setTargetEngine(samediff::Engine engine) {
|
void Context::setTargetEngine(sd::Engine engine) {
|
||||||
_engine = engine;
|
_engine = engine;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -548,20 +548,20 @@ namespace sd {
|
||||||
return _shapeFunctionOverride;
|
return _shapeFunctionOverride;
|
||||||
}
|
}
|
||||||
|
|
||||||
samediff::ExecutionMode Context::executionMode() {
|
sd::ExecutionMode Context::executionMode() {
|
||||||
return _execMode;
|
return _execMode;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Context::setExecutionMode(samediff::ExecutionMode executionMode) {
|
void Context::setExecutionMode(sd::ExecutionMode executionMode) {
|
||||||
_execMode = executionMode;
|
_execMode = executionMode;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Context::isTraining() {
|
bool Context::isTraining() {
|
||||||
return _execMode == samediff::ExecutionMode::MODE_TRAINING;
|
return _execMode == sd::ExecutionMode::MODE_TRAINING;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Context::isInference() {
|
bool Context::isInference() {
|
||||||
return _execMode == samediff::ExecutionMode::MODE_INFERENCE;
|
return _execMode == sd::ExecutionMode::MODE_INFERENCE;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Context::setDArguments(sd::DataType *arguments, int numberOfArguments) {
|
void Context::setDArguments(sd::DataType *arguments, int numberOfArguments) {
|
||||||
|
|
|
@ -59,7 +59,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
samediff::Engine ContextPrototype::engine() {
|
sd::Engine ContextPrototype::engine() {
|
||||||
return _engine;
|
return _engine;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -511,7 +511,7 @@ namespace sd {
|
||||||
|
|
||||||
//*********************************************//
|
//*********************************************//
|
||||||
case LoopKind::EWS1: {
|
case LoopKind::EWS1: {
|
||||||
auto span = samediff::Span::build(threadId, numThreads, 0, len, 1);
|
auto span = sd::Span::build(threadId, numThreads, 0, len, 1);
|
||||||
int64_t start = span.startX(), stop = span.stopX();
|
int64_t start = span.startX(), stop = span.stopX();
|
||||||
|
|
||||||
for (auto i = start; i < stop; i++)
|
for (auto i = start; i < stop; i++)
|
||||||
|
@ -524,7 +524,7 @@ namespace sd {
|
||||||
const uint xEws = shape::elementWiseStride(xShapeInfo);
|
const uint xEws = shape::elementWiseStride(xShapeInfo);
|
||||||
const uint zEws = shape::elementWiseStride(zShapeInfo);
|
const uint zEws = shape::elementWiseStride(zShapeInfo);
|
||||||
|
|
||||||
auto span = samediff::Span::build(threadId, numThreads, 0, len, 1);
|
auto span = sd::Span::build(threadId, numThreads, 0, len, 1);
|
||||||
int64_t start = span.startX(), stop = span.stopX();
|
int64_t start = span.startX(), stop = span.stopX();
|
||||||
|
|
||||||
for (auto i = start; i < stop; i++)
|
for (auto i = start; i < stop; i++)
|
||||||
|
@ -538,7 +538,7 @@ namespace sd {
|
||||||
uint castXShapeInfo[MAX_RANK];
|
uint castXShapeInfo[MAX_RANK];
|
||||||
const bool canCastX = sd::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, castXShapeInfo);
|
const bool canCastX = sd::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, castXShapeInfo);
|
||||||
|
|
||||||
auto span = samediff::Span::build(threadId, numThreads, 0, len, 1);
|
auto span = sd::Span::build(threadId, numThreads, 0, len, 1);
|
||||||
int64_t start = span.startX(), stop = span.stopX();
|
int64_t start = span.startX(), stop = span.stopX();
|
||||||
|
|
||||||
if (zEws > 1) {
|
if (zEws > 1) {
|
||||||
|
@ -558,7 +558,7 @@ namespace sd {
|
||||||
|
|
||||||
//*********************************************//
|
//*********************************************//
|
||||||
case LoopKind::RANK1: {
|
case LoopKind::RANK1: {
|
||||||
auto span = samediff::Span::build(threadId, numThreads, 0, len, 1);
|
auto span = sd::Span::build(threadId, numThreads, 0, len, 1);
|
||||||
|
|
||||||
for (auto i0 = span.startX(); i0 < span.stopX(); i0++)
|
for (auto i0 = span.startX(); i0 < span.stopX(); i0++)
|
||||||
z[i0 * zStride[0]] = OpType::op(x[i0 * xStride[0]], extraParams);
|
z[i0 * zStride[0]] = OpType::op(x[i0 * xStride[0]], extraParams);
|
||||||
|
@ -570,8 +570,8 @@ namespace sd {
|
||||||
auto uXShape0 = static_cast<uint>(xShape[0]);
|
auto uXShape0 = static_cast<uint>(xShape[0]);
|
||||||
auto uXShape1 = static_cast<uint>(xShape[1]);
|
auto uXShape1 = static_cast<uint>(xShape[1]);
|
||||||
|
|
||||||
auto loop = samediff::ThreadsHelper::pickLoop2d(numThreads, uXShape0, uXShape1);
|
auto loop = sd::ThreadsHelper::pickLoop2d(numThreads, uXShape0, uXShape1);
|
||||||
auto span = samediff::Span2::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1);
|
auto span = sd::Span2::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1);
|
||||||
|
|
||||||
for (auto i0 = span.startX(); i0 < span.stopX(); i0++) {
|
for (auto i0 = span.startX(); i0 < span.stopX(); i0++) {
|
||||||
auto z0 = i0 * zStride[0];
|
auto z0 = i0 * zStride[0];
|
||||||
|
@ -589,8 +589,8 @@ namespace sd {
|
||||||
auto uXShape1 = xShape[1];
|
auto uXShape1 = xShape[1];
|
||||||
auto uXShape2 = xShape[2];
|
auto uXShape2 = xShape[2];
|
||||||
|
|
||||||
auto loop = samediff::ThreadsHelper::pickLoop2d(numThreads, uXShape0, uXShape1);
|
auto loop = sd::ThreadsHelper::pickLoop2d(numThreads, uXShape0, uXShape1);
|
||||||
auto span = samediff::Span2::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1);
|
auto span = sd::Span2::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1);
|
||||||
|
|
||||||
|
|
||||||
for (auto i0 = span.startX(); i0 < span.stopX(); i0++)
|
for (auto i0 = span.startX(); i0 < span.stopX(); i0++)
|
||||||
|
@ -611,8 +611,8 @@ namespace sd {
|
||||||
auto uXShape2 = xShape[2];
|
auto uXShape2 = xShape[2];
|
||||||
auto uXShape3 = xShape[3];
|
auto uXShape3 = xShape[3];
|
||||||
|
|
||||||
auto loop = samediff::ThreadsHelper::pickLoop3d(numThreads, uXShape0, uXShape1, uXShape2);
|
auto loop = sd::ThreadsHelper::pickLoop3d(numThreads, uXShape0, uXShape1, uXShape2);
|
||||||
auto span = samediff::Span3::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1, 0, uXShape2, 1);
|
auto span = sd::Span3::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1, 0, uXShape2, 1);
|
||||||
|
|
||||||
for (auto i0 = span.startX(); i0 < span.stopX(); i0++)
|
for (auto i0 = span.startX(); i0 < span.stopX(); i0++)
|
||||||
for (auto i1 = span.startY(); i1 < span.stopY(); i1++)
|
for (auto i1 = span.startY(); i1 < span.stopY(); i1++)
|
||||||
|
@ -634,8 +634,8 @@ namespace sd {
|
||||||
auto uXShape3 = xShape[3];
|
auto uXShape3 = xShape[3];
|
||||||
auto uXShape4 = xShape[4];
|
auto uXShape4 = xShape[4];
|
||||||
|
|
||||||
auto loop = samediff::ThreadsHelper::pickLoop3d(numThreads, uXShape0, uXShape1, uXShape2);
|
auto loop = sd::ThreadsHelper::pickLoop3d(numThreads, uXShape0, uXShape1, uXShape2);
|
||||||
auto span = samediff::Span3::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1, 0, uXShape2, 1);
|
auto span = sd::Span3::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1, 0, uXShape2, 1);
|
||||||
|
|
||||||
|
|
||||||
for (auto i0 = span.startX(); i0 < span.stopX(); i0++)
|
for (auto i0 = span.startX(); i0 < span.stopX(); i0++)
|
||||||
|
@ -666,7 +666,7 @@ namespace sd {
|
||||||
bool canCastX = DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
bool canCastX = DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
bool canCastZ = DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
bool canCastZ = DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
||||||
|
|
||||||
auto span = samediff::Span::build(threadId, numThreads, 0, len, 1);
|
auto span = sd::Span::build(threadId, numThreads, 0, len, 1);
|
||||||
|
|
||||||
for (auto i = span.startX(); i < span.stopX(); i++) {
|
for (auto i = span.startX(); i < span.stopX(); i++) {
|
||||||
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
|
|
|
@ -93,7 +93,7 @@ static void usualGemm(const NDArray* vA, const NDArray* vB, NDArray* vC,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, cLen);
|
sd::Threads::parallel_tad(func, 0, cLen);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -146,7 +146,7 @@ static void usualGemv(const NDArray* vA, const NDArray* vX, NDArray* vY, const
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, M);
|
sd::Threads::parallel_tad(func, 0, M);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -477,7 +477,7 @@ static void batchedGemm(const NDArray* vA, const NDArray* vB, NDArray* vC,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, cLen);
|
sd::Threads::parallel_tad(func, 0, cLen);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -669,7 +669,7 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, M, 1, 0, N, 1);
|
sd::Threads::parallel_tad(func, 0, M, 1, 0, N, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -703,7 +703,7 @@ static void usualGemv(const char aOrder, const int M, const int N, const double
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, M);
|
sd::Threads::parallel_tad(func, 0, M);
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
|
@ -62,7 +62,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zLen);
|
sd::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -83,7 +83,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zLen);
|
sd::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -104,7 +104,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zLen);
|
sd::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -131,7 +131,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zLen);
|
sd::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -160,7 +160,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zLen);
|
sd::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -191,7 +191,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zLen);
|
sd::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -224,7 +224,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zLen);
|
sd::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -248,7 +248,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zLen);
|
sd::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -272,7 +272,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zLen);
|
sd::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -299,7 +299,7 @@ void sd::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zLen);
|
sd::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -99,7 +99,7 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided) reduction(+:_nanCount,_infCount,_m
|
||||||
|
|
||||||
return _stdDevValue;
|
return _stdDevValue;
|
||||||
};
|
};
|
||||||
_stdDevValue = samediff::Threads::parallel_double(func, LAMBDA_AD { return _old + _new; }, 0, input->lengthOf());
|
_stdDevValue = sd::Threads::parallel_double(func, LAMBDA_AD { return _old + _new; }, 0, input->lengthOf());
|
||||||
|
|
||||||
info->_stdDevValue = math::nd4j_sqrt<double, double>(_stdDevValue / input->lengthOf());
|
info->_stdDevValue = math::nd4j_sqrt<double, double>(_stdDevValue / input->lengthOf());
|
||||||
|
|
||||||
|
|
|
@ -199,7 +199,7 @@ void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numTads);
|
sd::Threads::parallel_tad(func, 0, numTads);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
@ -237,7 +237,7 @@ void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc,
|
||||||
auto yLen = shape::length(hYShapeInfo);
|
auto yLen = shape::length(hYShapeInfo);
|
||||||
auto numTads = yLen / xLen;
|
auto numTads = yLen / xLen;
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numTads);
|
sd::Threads::parallel_tad(func, 0, numTads);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -273,7 +273,7 @@ void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext *lc,
|
||||||
auto yLen = shape::length(hYShapeInfo);
|
auto yLen = shape::length(hYShapeInfo);
|
||||||
auto numTads = xLen / yLen;
|
auto numTads = xLen / yLen;
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numTads);
|
sd::Threads::parallel_tad(func, 0, numTads);
|
||||||
}
|
}
|
||||||
|
|
||||||
void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc,
|
void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc,
|
||||||
|
@ -308,7 +308,7 @@ void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc,
|
||||||
auto yLen = shape::length(hYShapeInfo);
|
auto yLen = shape::length(hYShapeInfo);
|
||||||
auto numTads = yLen / xLen;
|
auto numTads = yLen / xLen;
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numTads);
|
sd::Threads::parallel_tad(func, 0, numTads);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -348,7 +348,7 @@ void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc,
|
||||||
auto yLen = shape::length(hYShapeInfo);
|
auto yLen = shape::length(hYShapeInfo);
|
||||||
auto numTads = xLen / yLen;
|
auto numTads = xLen / yLen;
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numTads);
|
sd::Threads::parallel_tad(func, 0, numTads);
|
||||||
}
|
}
|
||||||
|
|
||||||
void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc,
|
void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc,
|
||||||
|
@ -384,7 +384,7 @@ void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc,
|
||||||
auto yLen = shape::length(hYShapeInfo);
|
auto yLen = shape::length(hYShapeInfo);
|
||||||
auto numTads = yLen / xLen;
|
auto numTads = yLen / xLen;
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numTads);
|
sd::Threads::parallel_tad(func, 0, numTads);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -427,7 +427,7 @@ void NativeOpExecutioner::execPairwiseTransform(sd::LaunchContext *lc,
|
||||||
};
|
};
|
||||||
|
|
||||||
auto zLen = shape::length(hZShapeInfo);
|
auto zLen = shape::length(hZShapeInfo);
|
||||||
samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
sd::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
@ -462,7 +462,7 @@ void NativeOpExecutioner::execPairwiseBoolTransform(sd::LaunchContext *lc,
|
||||||
};
|
};
|
||||||
|
|
||||||
auto zLen = shape::length(hZShapeInfo);
|
auto zLen = shape::length(hZShapeInfo);
|
||||||
samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
sd::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -495,7 +495,7 @@ void NativeOpExecutioner::execPairwiseIntTransform(sd::LaunchContext *lc,
|
||||||
};
|
};
|
||||||
|
|
||||||
auto zLen = shape::length(hZShapeInfo);
|
auto zLen = shape::length(hZShapeInfo);
|
||||||
samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
sd::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -534,7 +534,7 @@ void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
|
||||||
|
|
||||||
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
|
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
|
sd::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -562,7 +562,7 @@ void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc,
|
||||||
|
|
||||||
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
|
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
|
sd::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -590,7 +590,7 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
|
||||||
|
|
||||||
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
|
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
|
sd::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -618,7 +618,7 @@ void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
|
||||||
|
|
||||||
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
|
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
|
sd::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -791,7 +791,7 @@ void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
|
sd::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -820,7 +820,7 @@ void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc,
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execAll(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execAll(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
|
sd::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -861,7 +861,7 @@ void NativeOpExecutioner::execReduce3TAD(sd::LaunchContext *lc,
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
|
sd::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -905,7 +905,7 @@ void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
|
||||||
};
|
};
|
||||||
|
|
||||||
auto zLen = shape::length(hZShapeInfo);
|
auto zLen = shape::length(hZShapeInfo);
|
||||||
samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
sd::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
@ -942,7 +942,7 @@ void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
|
||||||
};
|
};
|
||||||
|
|
||||||
auto yLen = shape::length(hScalarShapeInfo);
|
auto yLen = shape::length(hScalarShapeInfo);
|
||||||
samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads()));
|
sd::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads()));
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
@ -976,7 +976,7 @@ void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
|
||||||
};
|
};
|
||||||
|
|
||||||
auto zLen = shape::length(hZShapeInfo);
|
auto zLen = shape::length(hZShapeInfo);
|
||||||
samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
sd::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1012,7 +1012,7 @@ void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
|
||||||
};
|
};
|
||||||
|
|
||||||
auto yLen = shape::length(hScalarShapeInfo);
|
auto yLen = shape::length(hScalarShapeInfo);
|
||||||
samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads()));
|
sd::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads()));
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1044,7 +1044,7 @@ void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
|
||||||
};
|
};
|
||||||
|
|
||||||
auto zLen = shape::length(hZShapeInfo);
|
auto zLen = shape::length(hZShapeInfo);
|
||||||
samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
sd::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1080,7 +1080,7 @@ void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
|
||||||
};
|
};
|
||||||
|
|
||||||
auto yLen = shape::length(hScalarShapeInfo);
|
auto yLen = shape::length(hScalarShapeInfo);
|
||||||
samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads()));
|
sd::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads()));
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1193,7 +1193,7 @@ void NativeOpExecutioner::execTransformFloat(sd::LaunchContext *lc,
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, FLOAT_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
sd::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1215,7 +1215,7 @@ void NativeOpExecutioner::execTransformBool(sd::LaunchContext *lc,
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, BOOL_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, BOOL_TYPES);
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
sd::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1243,7 +1243,7 @@ void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc,
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
sd::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1266,7 +1266,7 @@ void NativeOpExecutioner::execTransformSame(sd::LaunchContext *lc,
|
||||||
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES);
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
sd::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1288,7 +1288,7 @@ void NativeOpExecutioner::execTransformStrict(sd::LaunchContext *lc,
|
||||||
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), FLOAT_TYPES);
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
sd::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -1318,7 +1318,7 @@ void pullRowsGeneric(void *vx,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, n, 1, _threads);
|
sd::Threads::parallel_tad(func, 0, n, 1, _threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
void pullRows(Nd4jPointer *extraPointers,
|
void pullRows(Nd4jPointer *extraPointers,
|
||||||
|
@ -1377,7 +1377,7 @@ void tearGeneric(void *vx,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func,0, numTads);
|
sd::Threads::parallel_tad(func,0, numTads);
|
||||||
}
|
}
|
||||||
|
|
||||||
void tear(Nd4jPointer *extraPointers,
|
void tear(Nd4jPointer *extraPointers,
|
||||||
|
@ -1530,7 +1530,7 @@ void shuffleGeneric(void **hX, Nd4jLong **hXShapeInfo, void **dz, Nd4jLong **hZS
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, N);
|
sd::Threads::parallel_tad(func, 0, N);
|
||||||
}
|
}
|
||||||
|
|
||||||
void shuffle(Nd4jPointer *extras,
|
void shuffle(Nd4jPointer *extras,
|
||||||
|
@ -1944,7 +1944,7 @@ FORCEINLINE int estimateThresholdGeneric(Nd4jPointer *extraPointers, Nd4jPointer
|
||||||
return cnt;
|
return cnt;
|
||||||
};
|
};
|
||||||
|
|
||||||
return samediff::Threads::parallel_long(func, LAMBDA_AL { return _old + _new; }, 0, N);
|
return sd::Threads::parallel_long(func, LAMBDA_AL { return _old + _new; }, 0, N);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -2653,7 +2653,7 @@ static void _scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSub
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_do(func);
|
sd::Threads::parallel_do(func);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2812,7 +2812,7 @@ void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) {
|
||||||
if (execMode < 0 || execMode > 2)
|
if (execMode < 0 || execMode > 2)
|
||||||
execMode = 0;
|
execMode = 0;
|
||||||
|
|
||||||
ptr->setExecutionMode((samediff::ExecutionMode) execMode);
|
ptr->setExecutionMode((sd::ExecutionMode) execMode);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ctxPurge(OpaqueContext* ptr) {
|
void ctxPurge(OpaqueContext* ptr) {
|
||||||
|
|
|
@ -3799,7 +3799,7 @@ void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) {
|
||||||
if (execMode < 0 || execMode > 2)
|
if (execMode < 0 || execMode > 2)
|
||||||
execMode = 0;
|
execMode = 0;
|
||||||
|
|
||||||
ptr->setExecutionMode((samediff::ExecutionMode) execMode);
|
ptr->setExecutionMode((sd::ExecutionMode) execMode);
|
||||||
}
|
}
|
||||||
|
|
||||||
OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) {
|
OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) {
|
||||||
|
|
|
@ -60,7 +60,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_tad(func, 0, xArr.lengthOf());
|
sd::Threads::parallel_tad(func, 0, xArr.lengthOf());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_tad(func, 0, nLen, 1);
|
sd::Threads::parallel_tad(func, 0, nLen, 1);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -137,7 +137,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, zLen);
|
sd::Threads::parallel_for(func, 0, zLen);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X, typename Y, typename Z>
|
template <typename X, typename Y, typename Z>
|
||||||
|
@ -200,7 +200,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, zLen);
|
sd::Threads::parallel_for(func, 0, zLen);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X, typename Y>
|
template <typename X, typename Y>
|
||||||
|
@ -263,7 +263,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, zLen);
|
sd::Threads::parallel_for(func, 0, zLen);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X>
|
template <typename X>
|
||||||
|
|
|
@ -79,7 +79,7 @@ Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vex
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
maxThreads = samediff::Threads::parallel_for(func, 0, len, 1, maxThreads);
|
maxThreads = sd::Threads::parallel_for(func, 0, len, 1, maxThreads);
|
||||||
|
|
||||||
for (int e = 0; e < maxThreads; e++)
|
for (int e = 0; e < maxThreads; e++)
|
||||||
startingIndex = OpType::update(startingIndex, intermediatery[e], extraParams);
|
startingIndex = OpType::update(startingIndex, intermediatery[e], extraParams);
|
||||||
|
@ -95,7 +95,7 @@ Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vex
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
maxThreads = samediff::Threads::parallel_for(func, 0, len, 1, maxThreads);
|
maxThreads = sd::Threads::parallel_for(func, 0, len, 1, maxThreads);
|
||||||
|
|
||||||
for (int e = 0; e < maxThreads; e++)
|
for (int e = 0; e < maxThreads; e++)
|
||||||
startingIndex = OpType::update(startingIndex, intermediatery[e], extraParams);
|
startingIndex = OpType::update(startingIndex, intermediatery[e], extraParams);
|
||||||
|
|
|
@ -67,7 +67,7 @@ namespace functions {
|
||||||
z[i] = OpClass::op(x[i], y[i], i, length, rng, extraArguments);
|
z[i] = OpClass::op(x[i], y[i], i, length, rng, extraArguments);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, length, 1);
|
sd::Threads::parallel_for(func, 0, length, 1);
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
|
@ -81,7 +81,7 @@ namespace functions {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, length, 1);
|
sd::Threads::parallel_for(func, 0, length, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
|
else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
|
||||||
|
@ -100,7 +100,7 @@ namespace functions {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, length, 1);
|
sd::Threads::parallel_for(func, 0, length, 1);
|
||||||
}
|
}
|
||||||
else if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
else if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
||||||
|
|
||||||
|
@ -118,7 +118,7 @@ namespace functions {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, length, 1);
|
sd::Threads::parallel_for(func, 0, length, 1);
|
||||||
}
|
}
|
||||||
else if (shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) {
|
else if (shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) {
|
||||||
|
|
||||||
|
@ -136,7 +136,7 @@ namespace functions {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, length, 1);
|
sd::Threads::parallel_for(func, 0, length, 1);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
|
@ -157,7 +157,7 @@ namespace functions {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, length, 1);
|
sd::Threads::parallel_for(func, 0, length, 1);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -192,7 +192,7 @@ namespace functions {
|
||||||
z[i] = OpClass::op(x[i], i, length, rng, extraArguments);
|
z[i] = OpClass::op(x[i], i, length, rng, extraArguments);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, length, 1);
|
sd::Threads::parallel_for(func, 0, length, 1);
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
@ -203,7 +203,7 @@ namespace functions {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, length, 1);
|
sd::Threads::parallel_for(func, 0, length, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
@ -220,7 +220,7 @@ namespace functions {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, length, 1);
|
sd::Threads::parallel_for(func, 0, length, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -245,7 +245,7 @@ namespace functions {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, length, 1);
|
sd::Threads::parallel_for(func, 0, length, 1);
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
sd::OmpLaunchHelper info(length);
|
sd::OmpLaunchHelper info(length);
|
||||||
|
@ -261,7 +261,7 @@ namespace functions {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, length, 1);
|
sd::Threads::parallel_for(func, 0, length, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -208,13 +208,26 @@ namespace functions {
|
||||||
Z _CUDA_H ReduceBoolFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
|
Z _CUDA_H ReduceBoolFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
|
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads());
|
||||||
Z intermediate[64];
|
Z intermediate[64];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (auto e = 0; e < maxThreads; e++)
|
for (auto e = 0; e < maxThreads; e++)
|
||||||
intermediate[e] = OpType::startingValue(x);
|
intermediate[e] = OpType::startingValue(x);
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
|
||||||
|
if (xEws == 1) {
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
|
||||||
|
for (Nd4jLong i = 0; i < length; i++)
|
||||||
|
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i], extraParams), extraParams);
|
||||||
|
} else {
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
|
||||||
|
for (Nd4jLong i = 0; i < length; i++)
|
||||||
|
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i * xEws], extraParams), extraParams);
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
if (xEws == 1) {
|
if (xEws == 1) {
|
||||||
for (auto i = start; i < stop; i++)
|
for (auto i = start; i < stop; i++)
|
||||||
|
@ -225,7 +238,9 @@ namespace functions {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
// merge results
|
// merge results
|
||||||
for (int e = 1; e < maxThreads; e++)
|
for (int e = 1; e < maxThreads; e++)
|
||||||
|
|
|
@ -72,7 +72,7 @@ namespace functions {
|
||||||
auto startingValue = OpType::startingValue(x);
|
auto startingValue = OpType::startingValue(x);
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
|
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads());
|
||||||
Z intermediate[64];
|
Z intermediate[64];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
|
@ -84,7 +84,7 @@ namespace functions {
|
||||||
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
||||||
};
|
};
|
||||||
|
|
||||||
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
||||||
|
|
||||||
// merge results
|
// merge results
|
||||||
for (int e = 1; e < maxThreads; e++)
|
for (int e = 1; e < maxThreads; e++)
|
||||||
|
@ -242,13 +242,27 @@ namespace functions {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto extraParams = reinterpret_cast<Z *>(vextraParams);
|
auto extraParams = reinterpret_cast<Z *>(vextraParams);
|
||||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
|
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads());
|
||||||
Z intermediate[64];
|
Z intermediate[64];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (auto e = 0; e < maxThreads; e++)
|
for (auto e = 0; e < maxThreads; e++)
|
||||||
intermediate[e] = OpType::startingValue(x);
|
intermediate[e] = OpType::startingValue(x);
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
|
||||||
|
if (xEws == 1) {
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
|
||||||
|
for (Nd4jLong i = 0; i < length; i++)
|
||||||
|
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i], extraParams), extraParams);
|
||||||
|
} else {
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
|
||||||
|
for (Nd4jLong i = 0; i < length; i++)
|
||||||
|
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i * xEws], extraParams), extraParams);
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
if (xEws == 1) {
|
if (xEws == 1) {
|
||||||
for (auto i = start; i < stop; i++)
|
for (auto i = start; i < stop; i++)
|
||||||
|
@ -259,7 +273,9 @@ namespace functions {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
// merge results
|
// merge results
|
||||||
for (int e = 1; e < maxThreads; e++)
|
for (int e = 1; e < maxThreads; e++)
|
||||||
|
|
|
@ -67,7 +67,7 @@ namespace functions {
|
||||||
auto startingValue = OpType::startingValue(x);
|
auto startingValue = OpType::startingValue(x);
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
|
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads());
|
||||||
Z intermediate[64];
|
Z intermediate[64];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
|
@ -79,7 +79,7 @@ namespace functions {
|
||||||
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
||||||
};
|
};
|
||||||
|
|
||||||
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
||||||
|
|
||||||
// merge results
|
// merge results
|
||||||
for (int e = 1; e < maxThreads; e++)
|
for (int e = 1; e < maxThreads; e++)
|
||||||
|
@ -231,13 +231,26 @@ namespace functions {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
|
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads());
|
||||||
Z intermediate[64];
|
Z intermediate[64];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (auto e = 0; e < maxThreads; e++)
|
for (auto e = 0; e < maxThreads; e++)
|
||||||
intermediate[e] = OpType::startingValue(x);
|
intermediate[e] = OpType::startingValue(x);
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
|
||||||
|
if (xEws == 1) {
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
|
||||||
|
for (Nd4jLong i = 0; i < length; i++)
|
||||||
|
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i], extraParams), extraParams);
|
||||||
|
} else {
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
|
||||||
|
for (Nd4jLong i = 0; i < length; i++)
|
||||||
|
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i * xEws], extraParams), extraParams);
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
if (xEws == 1) {
|
if (xEws == 1) {
|
||||||
for (auto i = start; i < stop; i++)
|
for (auto i = start; i < stop; i++)
|
||||||
|
@ -248,7 +261,9 @@ namespace functions {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
// merge results
|
// merge results
|
||||||
for (int e = 1; e < maxThreads; e++)
|
for (int e = 1; e < maxThreads; e++)
|
||||||
|
|
|
@ -69,7 +69,7 @@ namespace functions {
|
||||||
auto startingValue = OpType::startingValue(x);
|
auto startingValue = OpType::startingValue(x);
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
|
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads());
|
||||||
X intermediate[64];
|
X intermediate[64];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
|
@ -81,7 +81,7 @@ namespace functions {
|
||||||
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
||||||
};
|
};
|
||||||
|
|
||||||
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
||||||
|
|
||||||
// merge results
|
// merge results
|
||||||
for (int e = 1; e < maxThreads; e++)
|
for (int e = 1; e < maxThreads; e++)
|
||||||
|
@ -240,13 +240,26 @@ namespace functions {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
|
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxMasterThreads());
|
||||||
X intermediate[64];
|
X intermediate[64];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (auto e = 0; e < maxThreads; e++)
|
for (auto e = 0; e < maxThreads; e++)
|
||||||
intermediate[e] = OpType::startingValue(x);
|
intermediate[e] = OpType::startingValue(x);
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
|
||||||
|
if (xEws == 1) {
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
|
||||||
|
for (Nd4jLong i = 0; i < length; i++)
|
||||||
|
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i], extraParams), extraParams);
|
||||||
|
} else {
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_THREADS(maxThreads)
|
||||||
|
for (Nd4jLong i = 0; i < length; i++)
|
||||||
|
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[i * xEws], extraParams), extraParams);
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
if (xEws == 1) {
|
if (xEws == 1) {
|
||||||
for (auto i = start; i < stop; i++)
|
for (auto i = start; i < stop; i++)
|
||||||
|
@ -257,7 +270,9 @@ namespace functions {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
// merge results
|
// merge results
|
||||||
for (int e = 1; e < maxThreads; e++)
|
for (int e = 1; e < maxThreads; e++)
|
||||||
|
|
|
@ -93,7 +93,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
||||||
|
|
||||||
} else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
|
} else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
|
||||||
|
|
||||||
|
@ -104,7 +104,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
||||||
} else {
|
} else {
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
uint yShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
|
@ -117,7 +117,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
maxThreads = sd::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
||||||
}
|
}
|
||||||
|
|
||||||
// merge step
|
// merge step
|
||||||
|
|
|
@ -187,7 +187,7 @@ namespace functions {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, resultLength, 1);
|
sd::Threads::parallel_tad(func, 0, resultLength, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -86,7 +86,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, N);
|
sd::Threads::parallel_for(func, 0, N);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -184,7 +184,7 @@ PRAGMA_OMP_ATOMIC_ARGS(write)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 4, flimit);
|
sd::Threads::parallel_for(func, 4, flimit);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -206,7 +206,7 @@ PRAGMA_OMP_ATOMIC_ARGS(write)
|
||||||
z[i] = static_cast<T>(static_cast<float>(x[i]));
|
z[i] = static_cast<T>(static_cast<float>(x[i]));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, N);
|
sd::Threads::parallel_for(func, 0, N);
|
||||||
};
|
};
|
||||||
|
|
||||||
template void TypeCast::convertFromThreshold<float>(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz);
|
template void TypeCast::convertFromThreshold<float>(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz);
|
||||||
|
|
|
@ -112,7 +112,7 @@ namespace sd {
|
||||||
*/
|
*/
|
||||||
int prepareOutputs(Context& block);
|
int prepareOutputs(Context& block);
|
||||||
|
|
||||||
virtual samediff::EmptyHandling emptyHandling();
|
virtual sd::EmptyHandling emptyHandling();
|
||||||
public:
|
public:
|
||||||
// for special cases, like BooleanOps
|
// for special cases, like BooleanOps
|
||||||
DeclarableOp();
|
DeclarableOp();
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
#ifndef SAMEDIFF_EMPTYHANDLING_H
|
#ifndef SAMEDIFF_EMPTYHANDLING_H
|
||||||
#define SAMEDIFF_EMPTYHANDLING_H
|
#define SAMEDIFF_EMPTYHANDLING_H
|
||||||
|
|
||||||
namespace samediff {
|
namespace sd {
|
||||||
enum EmptyHandling {
|
enum EmptyHandling {
|
||||||
EMPTY_SKIP = 1,
|
EMPTY_SKIP = 1,
|
||||||
EMPTY_EXCEPTION = 2,
|
EMPTY_EXCEPTION = 2,
|
||||||
|
|
|
@ -38,15 +38,15 @@
|
||||||
namespace std {
|
namespace std {
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
class hash<std::pair<Nd4jLong, samediff::Engine>> {
|
class hash<std::pair<Nd4jLong, sd::Engine>> {
|
||||||
public:
|
public:
|
||||||
size_t operator()(const std::pair<Nd4jLong, samediff::Engine>& k) const;
|
size_t operator()(const std::pair<Nd4jLong, sd::Engine>& k) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
class hash<std::pair<std::string, samediff::Engine>> {
|
class hash<std::pair<std::string, sd::Engine>> {
|
||||||
public:
|
public:
|
||||||
size_t operator()(const std::pair<std::string, samediff::Engine>& k) const;
|
size_t operator()(const std::pair<std::string, sd::Engine>& k) const;
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -87,8 +87,8 @@ namespace sd {
|
||||||
std::vector<sd::ops::DeclarableOp *> _uniqueD;
|
std::vector<sd::ops::DeclarableOp *> _uniqueD;
|
||||||
|
|
||||||
// pointers to platform-specific helpers
|
// pointers to platform-specific helpers
|
||||||
MAP_IMPL<std::pair<Nd4jLong, samediff::Engine>, sd::ops::platforms::PlatformHelper*> _helpersLH;
|
MAP_IMPL<std::pair<Nd4jLong, sd::Engine>, sd::ops::platforms::PlatformHelper*> _helpersLH;
|
||||||
MAP_IMPL<std::pair<std::string, samediff::Engine>, sd::ops::platforms::PlatformHelper*> _helpersH;
|
MAP_IMPL<std::pair<std::string, sd::Engine>, sd::ops::platforms::PlatformHelper*> _helpersH;
|
||||||
std::vector<sd::ops::platforms::PlatformHelper*> _uniqueH;
|
std::vector<sd::ops::platforms::PlatformHelper*> _uniqueH;
|
||||||
|
|
||||||
std::mutex _locker;
|
std::mutex _locker;
|
||||||
|
@ -119,13 +119,13 @@ namespace sd {
|
||||||
|
|
||||||
void registerHelper(sd::ops::platforms::PlatformHelper* op);
|
void registerHelper(sd::ops::platforms::PlatformHelper* op);
|
||||||
|
|
||||||
bool hasHelper(Nd4jLong hash, samediff::Engine engine);
|
bool hasHelper(Nd4jLong hash, sd::Engine engine);
|
||||||
|
|
||||||
sd::ops::DeclarableOp* getOperation(const char *name);
|
sd::ops::DeclarableOp* getOperation(const char *name);
|
||||||
sd::ops::DeclarableOp* getOperation(Nd4jLong hash);
|
sd::ops::DeclarableOp* getOperation(Nd4jLong hash);
|
||||||
sd::ops::DeclarableOp* getOperation(std::string &name);
|
sd::ops::DeclarableOp* getOperation(std::string &name);
|
||||||
|
|
||||||
sd::ops::platforms::PlatformHelper* getPlatformHelper(Nd4jLong hash, samediff::Engine engine);
|
sd::ops::platforms::PlatformHelper* getPlatformHelper(Nd4jLong hash, sd::Engine engine);
|
||||||
|
|
||||||
std::vector<Nd4jLong> getAllHashes();
|
std::vector<Nd4jLong> getAllHashes();
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ namespace sd {
|
||||||
class ND4J_EXPORT PlatformHelper {
|
class ND4J_EXPORT PlatformHelper {
|
||||||
protected:
|
protected:
|
||||||
// target engine for this impl
|
// target engine for this impl
|
||||||
samediff::Engine _engine;
|
sd::Engine _engine;
|
||||||
|
|
||||||
// name of the operation this helper is built for
|
// name of the operation this helper is built for
|
||||||
std::string _name;
|
std::string _name;
|
||||||
|
@ -45,13 +45,13 @@ namespace sd {
|
||||||
// hash of the operation this helper is built for
|
// hash of the operation this helper is built for
|
||||||
Nd4jLong _hash;
|
Nd4jLong _hash;
|
||||||
public:
|
public:
|
||||||
PlatformHelper(const char *name, samediff::Engine engine);
|
PlatformHelper(const char *name, sd::Engine engine);
|
||||||
|
|
||||||
~PlatformHelper() = default;
|
~PlatformHelper() = default;
|
||||||
|
|
||||||
std::string name();
|
std::string name();
|
||||||
|
|
||||||
samediff::Engine engine();
|
sd::Engine engine();
|
||||||
|
|
||||||
Nd4jLong hash();
|
Nd4jLong hash();
|
||||||
|
|
||||||
|
|
|
@ -174,7 +174,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, N);
|
sd::Threads::parallel_tad(func, 0, N);
|
||||||
}
|
}
|
||||||
|
|
||||||
void barnes_edge_forces(const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray* output, NDArray const& data) {
|
void barnes_edge_forces(const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray* output, NDArray const& data) {
|
||||||
|
|
|
@ -154,7 +154,7 @@ void prelu(sd::LaunchContext * context, const NDArray& input, const NDArray& alp
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, inputLen);
|
sd::Threads::parallel_for(func, 0, inputLen);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -565,7 +565,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
//
|
//
|
||||||
samediff::Threads::parallel_aligned_increment(func, 0, total_num, inc);
|
sd::Threads::parallel_aligned_increment(func, 0, total_num, inc);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
//NC...HW case here
|
//NC...HW case here
|
||||||
|
@ -631,7 +631,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
//
|
//
|
||||||
samediff::Threads::parallel_aligned_increment(func, 0, total_num, inc);
|
sd::Threads::parallel_aligned_increment(func, 0, total_num, inc);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -55,7 +55,7 @@ static void adjustHue_(const NDArray *input, const NDArray* deltaScalarArr, NDAr
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3);
|
sd::Threads::parallel_for(func, 0, input->lengthOf(), 3);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
|
@ -87,7 +87,7 @@ static void adjustHue_(const NDArray *input, const NDArray* deltaScalarArr, NDAr
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numOfTads);
|
sd::Threads::parallel_tad(func, 0, numOfTads);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -56,7 +56,7 @@ static void adjustSaturation_(const NDArray *input, const NDArray* factorScalarA
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3);
|
sd::Threads::parallel_for(func, 0, input->lengthOf(), 3);
|
||||||
} else {
|
} else {
|
||||||
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
|
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
|
||||||
auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC);
|
auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC);
|
||||||
|
@ -84,7 +84,7 @@ static void adjustSaturation_(const NDArray *input, const NDArray* factorScalarA
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numOfTads);
|
sd::Threads::parallel_tad(func, 0, numOfTads);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -114,7 +114,7 @@ void bgemm_(const std::vector<NDArray*>& vA, const std::vector<NDArray*>& vB, st
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, vaSize);
|
sd::Threads::parallel_tad(func, 0, vaSize);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -106,7 +106,7 @@ static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray*
|
||||||
delete []zOffsets;
|
delete []zOffsets;
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_do(func, info._numThreads);
|
sd::Threads::parallel_do(func, info._numThreads);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -178,7 +178,7 @@ static void batchnorm2_(const NDArray* input, const NDArray* mean, const NDArray
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, input->lengthOf());
|
sd::Threads::parallel_for(func, 0, input->lengthOf());
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -121,7 +121,7 @@ static void betaIncForArray(sd::LaunchContext * context, const NDArray& a, const
|
||||||
output.t<T>(i) = betaIncCore<T>(a.t<T>(i), b.t<T>(i), x.t<T>(i));
|
output.t<T>(i) = betaIncCore<T>(a.t<T>(i), b.t<T>(i), x.t<T>(i));
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, xLen);
|
sd::Threads::parallel_for(func, 0, xLen);
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -89,7 +89,7 @@ void col2im_(sd::LaunchContext & context, const NDArray& input, NDArray& output
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
|
@ -127,7 +127,7 @@ void col2im_(sd::LaunchContext & context, const NDArray& input, NDArray& output
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, bS);
|
sd::Threads::parallel_tad(func, 0, bS);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
return sum;
|
return sum;
|
||||||
};
|
};
|
||||||
sumt = samediff::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1);
|
sumt = sd::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1);
|
||||||
} else {
|
} else {
|
||||||
//PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:sum)
|
//PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:sum)
|
||||||
auto func = PRAGMA_REDUCE_LONG {
|
auto func = PRAGMA_REDUCE_LONG {
|
||||||
|
@ -53,7 +53,7 @@ namespace sd {
|
||||||
|
|
||||||
return sum;
|
return sum;
|
||||||
};
|
};
|
||||||
sumt = samediff::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1);
|
sumt = sd::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
//nd4j_printf("Sum: %lld\n", sumt)
|
//nd4j_printf("Sum: %lld\n", sumt)
|
||||||
|
|
|
@ -40,7 +40,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, lLen);
|
sd::Threads::parallel_for(func, 0, lLen);
|
||||||
}
|
}
|
||||||
|
|
||||||
void confusionFunctor(sd::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) {
|
void confusionFunctor(sd::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) {
|
||||||
|
|
|
@ -101,7 +101,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, kD, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, kD, 1);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
|
@ -139,7 +139,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, oD, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, oD, 1);
|
||||||
//func(0, 0, bS, 1, 0, oD, 1);
|
//func(0, 0, bS, 1, 0, oD, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -215,7 +215,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, bS);
|
sd::Threads::parallel_tad(func, 0, bS);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
|
@ -251,7 +251,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, bS);
|
sd::Threads::parallel_tad(func, 0, bS);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -606,7 +606,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oH, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oH, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -663,7 +663,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -716,7 +716,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iH, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iH, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -777,7 +777,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iD, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iD, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -860,7 +860,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
||||||
}
|
}
|
||||||
/*************************************************************************/
|
/*************************************************************************/
|
||||||
else if(poolingMode == 1) { // avg
|
else if(poolingMode == 1) { // avg
|
||||||
|
@ -914,7 +914,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
||||||
}
|
}
|
||||||
/*************************************************************************/
|
/*************************************************************************/
|
||||||
else if(poolingMode == 2) { // pnorm
|
else if(poolingMode == 2) { // pnorm
|
||||||
|
@ -963,7 +963,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
nd4j_printf("ConvolutionUtils::pooling2d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
|
nd4j_printf("ConvolutionUtils::pooling2d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
|
||||||
|
@ -1068,7 +1068,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
|
||||||
}
|
}
|
||||||
/*************************************************************************/
|
/*************************************************************************/
|
||||||
else if(poolingMode == 1) { // avg
|
else if(poolingMode == 1) { // avg
|
||||||
|
@ -1131,7 +1131,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
|
||||||
}
|
}
|
||||||
/*************************************************************************/
|
/*************************************************************************/
|
||||||
else if(poolingMode == 2) { // pnorm
|
else if(poolingMode == 2) { // pnorm
|
||||||
|
@ -1191,7 +1191,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
nd4j_printf("ConvolutionUtils::pooling3d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
|
nd4j_printf("ConvolutionUtils::pooling3d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
|
||||||
|
@ -1321,7 +1321,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
||||||
}
|
}
|
||||||
/*************************************************************************/
|
/*************************************************************************/
|
||||||
else if(poolingMode == 1) { // avg
|
else if(poolingMode == 1) { // avg
|
||||||
|
@ -1379,7 +1379,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
||||||
}
|
}
|
||||||
/*************************************************************************/
|
/*************************************************************************/
|
||||||
else if(poolingMode == 2) { // pnorm
|
else if(poolingMode == 2) { // pnorm
|
||||||
|
@ -1466,7 +1466,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
nd4j_printf("ConvolutionUtils::pooling2dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
|
nd4j_printf("ConvolutionUtils::pooling2dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
|
||||||
|
@ -1618,7 +1618,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
|
||||||
}
|
}
|
||||||
/*************************************************************************/
|
/*************************************************************************/
|
||||||
else if(poolingMode == 1) { // avg
|
else if(poolingMode == 1) { // avg
|
||||||
|
@ -1679,7 +1679,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
|
||||||
}
|
}
|
||||||
/*************************************************************************/
|
/*************************************************************************/
|
||||||
else if(poolingMode == 2) { // pnorm
|
else if(poolingMode == 2) { // pnorm
|
||||||
|
@ -1761,7 +1761,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
nd4j_printf("ConvolutionUtils::pooling3dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
|
nd4j_printf("ConvolutionUtils::pooling3dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode);
|
||||||
|
|
|
@ -115,7 +115,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, cropHeight);
|
sd::Threads::parallel_for(func, 0, cropHeight);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,7 +48,7 @@ void crossBatched(sd::LaunchContext * context, NDArray *a, NDArray *b, NDArray *
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, tads);
|
sd::Threads::parallel_tad(func, 0, tads);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -65,7 +65,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, total_count);
|
sd::Threads::parallel_for(func, 0, total_count);
|
||||||
} else {
|
} else {
|
||||||
const int total_count = batch_size * input_depth_by_input_area;
|
const int total_count = batch_size * input_depth_by_input_area;
|
||||||
|
|
||||||
|
@ -89,7 +89,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, total_count);
|
sd::Threads::parallel_for(func, 0, total_count);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ static void diGamma_(const NDArray& x, NDArray& z) {
|
||||||
for (auto i = start; i < stop; i++)
|
for (auto i = start; i < stop; i++)
|
||||||
z.p(i, diGammaScalar<T>(x.e<T>(i)));
|
z.p(i, diGammaScalar<T>(x.e<T>(i)));
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, x.lengthOf());
|
sd::Threads::parallel_for(func, 0, x.lengthOf());
|
||||||
}
|
}
|
||||||
|
|
||||||
void diGamma(sd::LaunchContext* context, const NDArray& x, NDArray& z) {
|
void diGamma(sd::LaunchContext* context, const NDArray& x, NDArray& z) {
|
||||||
|
|
|
@ -87,7 +87,7 @@ static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, const
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
void dilation2d(sd::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {
|
void dilation2d(sd::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {
|
||||||
|
|
|
@ -43,7 +43,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, inLen);
|
sd::Threads::parallel_for(func, 0, inLen);
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES);
|
||||||
|
|
||||||
|
@ -137,7 +137,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, input->lengthOf());
|
sd::Threads::parallel_for(func, 0, input->lengthOf());
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -71,7 +71,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, outSize);
|
sd::Threads::parallel_tad(func, 0, outSize);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -177,7 +177,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, gradsSize);
|
sd::Threads::parallel_tad(func, 0, gradsSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
outputList[1]->assign(indices);
|
outputList[1]->assign(indices);
|
||||||
|
|
|
@ -82,7 +82,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, batchCount);
|
sd::Threads::parallel_tad(func, 0, batchCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -63,7 +63,7 @@ void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* in
|
||||||
output->p(i, input->e(indices->e<Nd4jLong>(i)));
|
output->p(i, input->e(indices->e<Nd4jLong>(i)));
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, output->lengthOf());
|
sd::Threads::parallel_for(func, 0, output->lengthOf());
|
||||||
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
@ -96,7 +96,7 @@ void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* in
|
||||||
memcpy(outBuff, inBuff, shape::length(inTadShapeInfo) * input->sizeOfT());
|
memcpy(outBuff, inBuff, shape::length(inTadShapeInfo) * input->sizeOfT());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
|
sd::Threads::parallel_tad(func, 0, numOfSubArrs);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
@ -112,7 +112,7 @@ void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* in
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
|
sd::Threads::parallel_tad(func, 0, numOfSubArrs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -148,7 +148,7 @@ void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* in
|
||||||
std::memcpy(outBuff, inBuff, shape::length(inTadShapeInfo) * input->sizeOfT());
|
std::memcpy(outBuff, inBuff, shape::length(inTadShapeInfo) * input->sizeOfT());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
|
sd::Threads::parallel_tad(func, 0, numOfSubArrs);
|
||||||
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
@ -167,7 +167,7 @@ void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* in
|
||||||
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
|
sd::Threads::parallel_tad(func, 0, numOfSubArrs);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -64,7 +64,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf);
|
maxThreads = sd::Threads::parallel_for(func, 0, lengthOf);
|
||||||
} else if (xEws > 1 && yEws > 1 && x.ordering() == y.ordering()) {
|
} else if (xEws > 1 && yEws > 1 && x.ordering() == y.ordering()) {
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto e = start; e < stop; e++) {
|
for (auto e = start; e < stop; e++) {
|
||||||
|
@ -75,7 +75,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf);
|
maxThreads = sd::Threads::parallel_for(func, 0, lengthOf);
|
||||||
} else {
|
} else {
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto e = start; e < stop; e++) {
|
for (auto e = start; e < stop; e++) {
|
||||||
|
@ -86,7 +86,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf);
|
maxThreads = sd::Threads::parallel_for(func, 0, lengthOf);
|
||||||
}
|
}
|
||||||
|
|
||||||
// accumulate intermediate variables into output array
|
// accumulate intermediate variables into output array
|
||||||
|
|
|
@ -54,7 +54,7 @@ namespace sd {
|
||||||
tempBuffer[b] = r;
|
tempBuffer[b] = r;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_tad(func, 0, numBlocks);
|
sd::Threads::parallel_tad(func, 0, numBlocks);
|
||||||
|
|
||||||
// we replace pointer with intermediate one, and repeat only one chunk left
|
// we replace pointer with intermediate one, and repeat only one chunk left
|
||||||
int iterationCount = 0;
|
int iterationCount = 0;
|
||||||
|
@ -76,7 +76,7 @@ namespace sd {
|
||||||
tempResult[b] = r;
|
tempResult[b] = r;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_tad(func2, 0, numBlocks);
|
sd::Threads::parallel_tad(func2, 0, numBlocks);
|
||||||
|
|
||||||
|
|
||||||
iterationCount++;
|
iterationCount++;
|
||||||
|
|
|
@ -90,7 +90,7 @@ static void im2col_(sd::LaunchContext & context, const NDArray& input, NDArray&
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
|
@ -124,7 +124,7 @@ static void im2col_(sd::LaunchContext & context, const NDArray& input, NDArray&
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -149,7 +149,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_tad(func, 0, batchSize);
|
sd::Threads::parallel_tad(func, 0, batchSize);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -178,7 +178,7 @@ namespace helpers {
|
||||||
interpolationData[i]._interpolarValue = in - in_f;
|
interpolationData[i]._interpolarValue = in - in_f;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, outSize);
|
sd::Threads::parallel_for(func, 0, outSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -240,7 +240,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_tad(func, 0, batchSize);
|
sd::Threads::parallel_tad(func, 0, batchSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
|
@ -285,7 +285,7 @@ namespace helpers {
|
||||||
xs[i]._topIndex *= channels;
|
xs[i]._topIndex *= channels;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, xsSize);
|
sd::Threads::parallel_for(func, 0, xsSize);
|
||||||
|
|
||||||
resizeImage_<X,Z>(images->getDataBuffer()->primaryAsT<X>(), batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output->dataBuffer()->primaryAsT<Z>());
|
resizeImage_<X,Z>(images->getDataBuffer()->primaryAsT<X>(), batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output->dataBuffer()->primaryAsT<Z>());
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -323,7 +323,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, batchSize, 1, 0, outHeight, 1);
|
sd::Threads::parallel_for(func, 0, batchSize, 1, 0, outHeight, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
|
@ -427,7 +427,7 @@ namespace helpers {
|
||||||
coeffs_table[i * 2 + 1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
|
coeffs_table[i * 2 + 1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, kTableSize);
|
sd::Threads::parallel_for(func, 0, kTableSize);
|
||||||
return coeffs_table;
|
return coeffs_table;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -541,7 +541,7 @@ namespace helpers {
|
||||||
x_wai._index3);
|
x_wai._index3);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, resizer_state.outWidth);
|
sd::Threads::parallel_for(func, 0, resizer_state.outWidth);
|
||||||
} else {
|
} else {
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto x = start; x < stop; ++x) {
|
for (auto x = start; x < stop; ++x) {
|
||||||
|
@ -552,7 +552,7 @@ namespace helpers {
|
||||||
x_wai._index3);
|
x_wai._index3);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, resizer_state.outWidth);
|
sd::Threads::parallel_for(func, 0, resizer_state.outWidth);
|
||||||
}
|
}
|
||||||
// Scale the values so they can be used as offsets into buffers.
|
// Scale the values so they can be used as offsets into buffers.
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
@ -563,7 +563,7 @@ namespace helpers {
|
||||||
(*x_wais)[x]._index3 *= resizer_state.channels;
|
(*x_wais)[x]._index3 *= resizer_state.channels;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, resizer_state.outWidth);
|
sd::Threads::parallel_for(func, 0, resizer_state.outWidth);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -774,7 +774,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_tad(func, 0, batchNum);
|
sd::Threads::parallel_tad(func, 0, batchNum);
|
||||||
}
|
}
|
||||||
|
|
||||||
// simplified bicubic resize without antialiasing
|
// simplified bicubic resize without antialiasing
|
||||||
|
@ -950,7 +950,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_tad(batchProcess, 0, st.batchSize, 1);
|
sd::Threads::parallel_tad(batchProcess, 0, st.batchSize, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X>
|
template <typename X>
|
||||||
|
@ -981,7 +981,7 @@ namespace helpers {
|
||||||
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(cachingProcedure, 0, xCached.size(), 1);
|
sd::Threads::parallel_for(cachingProcedure, 0, xCached.size(), 1);
|
||||||
|
|
||||||
resizeArea<X>(st, xCached, image, output);
|
resizeArea<X>(st, xCached, image, output);
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,7 @@ static void rgbToGrs_(const NDArray& input, NDArray& output, const int dimC) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
|
sd::Threads::parallel_for(func, 0, output.lengthOf(), 1);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -62,7 +62,7 @@ static void rgbToGrs_(const NDArray& input, NDArray& output, const int dimC) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1);
|
sd::Threads::parallel_for(func, 0, output.lengthOf(), 1);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,7 +87,7 @@ FORCEINLINE static void rgbToFromYuv_(const NDArray& input, NDArray& output, con
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, input.lengthOf(), 3);
|
sd::Threads::parallel_for(func, 0, input.lengthOf(), 3);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -106,7 +106,7 @@ FORCEINLINE static void rgbToFromYuv_(const NDArray& input, NDArray& output, con
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numOfTads);
|
sd::Threads::parallel_tad(func, 0, numOfTads);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -146,7 +146,7 @@ FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3);
|
sd::Threads::parallel_for(func, 0, input->lengthOf(), 3);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
|
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
|
||||||
|
@ -165,7 +165,7 @@ FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numOfTads);
|
sd::Threads::parallel_tad(func, 0, numOfTads);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -196,7 +196,7 @@ FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3);
|
sd::Threads::parallel_for(func, 0, input->lengthOf(), 3);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
|
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
|
||||||
|
@ -222,7 +222,7 @@ FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numOfTads);
|
sd::Threads::parallel_tad(func, 0, numOfTads);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -195,7 +195,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, tads);
|
sd::Threads::parallel_tad(func, 0, tads);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -96,7 +96,7 @@ static int lrnFunctor_(sd::graph::Context& block, NDArray* input, NDArray* outpu
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numOfTads);
|
sd::Threads::parallel_tad(func, 0, numOfTads);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
@ -134,7 +134,7 @@ static int lrnFunctor_(sd::graph::Context& block, NDArray* input, NDArray* outpu
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numOfTads);
|
sd::Threads::parallel_tad(func, 0, numOfTads);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -242,7 +242,7 @@ static void lrnBP_(const NDArray& input, const NDArray& gradO, NDArray& gradI, c
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numOfTads);
|
sd::Threads::parallel_tad(func, 0, numOfTads);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
|
@ -317,7 +317,7 @@ static void lrnBP_(const NDArray& input, const NDArray& gradO, NDArray& gradI, c
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numOfTads);
|
sd::Threads::parallel_tad(func, 0, numOfTads);
|
||||||
}
|
}
|
||||||
gradI *= gradO;
|
gradI *= gradO;
|
||||||
}
|
}
|
||||||
|
|
|
@ -130,7 +130,7 @@ static void fusedTanh(NDArray *z, NDArray *i, NDArray *c, const NDArray *cLast,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, uLen);
|
sd::Threads::parallel_for(func, 0, uLen);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -54,7 +54,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(loop, 0, n, 1);
|
sd::Threads::parallel_tad(loop, 0, n, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,8 +79,8 @@ namespace helpers {
|
||||||
invertedMatrix->t<T>(i, i - 1) -= (inputMatrix->t<T>(i, i - 1) * invertedMatrix->t<T>(i - 1, i - 1) / inputMatrix->t<T>(i, i));
|
invertedMatrix->t<T>(i, i - 1) -= (inputMatrix->t<T>(i, i - 1) * invertedMatrix->t<T>(i - 1, i - 1) / inputMatrix->t<T>(i, i));
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(invertDiagonals, 0, n, 1);
|
sd::Threads::parallel_for(invertDiagonals, 0, n, 1);
|
||||||
samediff::Threads::parallel_for(invertSubDiagonals, 1, n, 1);
|
sd::Threads::parallel_for(invertSubDiagonals, 1, n, 1);
|
||||||
|
|
||||||
// PRAGMA_OMP_PARALLEL_FOR_SIMD
|
// PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
for (int i = 1; i < n; i++) {
|
for (int i = 1; i < n; i++) {
|
||||||
|
@ -118,8 +118,8 @@ namespace helpers {
|
||||||
inputMatrix->t<T>(i, i));
|
inputMatrix->t<T>(i, i));
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(invertDiagonals, 0, n, 1);
|
sd::Threads::parallel_for(invertDiagonals, 0, n, 1);
|
||||||
samediff::Threads::parallel_for(invertUpDiagonals, 0, n - 1, 1);
|
sd::Threads::parallel_for(invertUpDiagonals, 0, n - 1, 1);
|
||||||
|
|
||||||
// PRAGMA_OMP_PARALLEL_FOR_SIMD
|
// PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
for (auto i = n - 2; i >= 0; i--) {
|
for (auto i = n - 2; i >= 0; i--) {
|
||||||
|
@ -225,7 +225,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
//};
|
//};
|
||||||
//samediff::Threads::parallel_for(loop, column, rowNum, 1);
|
//sd::Threads::parallel_for(loop, column, rowNum, 1);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -247,7 +247,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1);
|
sd::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -327,7 +327,7 @@ namespace helpers {
|
||||||
luNN_<T, I>(context, outputs.at(i), permutationVectors?permutations.at(i):nullptr, n);
|
luNN_<T, I>(context, outputs.at(i), permutationVectors?permutations.at(i):nullptr, n);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(loop, 0, outputs.size(), 1);
|
sd::Threads::parallel_for(loop, 0, outputs.size(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) {
|
void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) {
|
||||||
|
|
|
@ -63,7 +63,7 @@ void matrixSetDiag_(const NDArray& input, const NDArray& diagonal, NDArray& outp
|
||||||
z[zOffset] = zeroPad ? static_cast<T>(0) : x[xOffset];
|
z[zOffset] = zeroPad ? static_cast<T>(0) : x[xOffset];
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, xLen);
|
sd::Threads::parallel_for(func, 0, xLen);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -51,7 +51,7 @@ int _matrixDiagPart(const NDArray* input, NDArray* output) {
|
||||||
listOut.at(i)->p(j, listDiag.at(i)->e<T>(j, j));
|
listOut.at(i)->p(j, listDiag.at(i)->e<T>(j, j));
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, lO);
|
sd::Threads::parallel_tad(func, 0, lO);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,7 +61,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, oL);
|
sd::Threads::parallel_for(func, 0, oL);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -67,7 +67,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numTads);
|
sd::Threads::parallel_tad(func, 0, numTads);
|
||||||
} else {
|
} else {
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto e = start; e < stop; e++) {
|
for (auto e = start; e < stop; e++) {
|
||||||
|
@ -88,7 +88,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numTads);
|
sd::Threads::parallel_tad(func, 0, numTads);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -80,7 +80,7 @@ static void polyGamma_(sd::LaunchContext * context, const NDArray& n, const NDAr
|
||||||
output.p(i, polyGammaScalar<T>(context, order, x.e<T>(i)));
|
output.p(i, polyGammaScalar<T>(context, order, x.e<T>(i)));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, x.lengthOf());
|
sd::Threads::parallel_for(func, 0, x.lengthOf());
|
||||||
}
|
}
|
||||||
|
|
||||||
void polyGamma(sd::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) {
|
void polyGamma(sd::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) {
|
||||||
|
|
|
@ -48,7 +48,7 @@ namespace helpers {
|
||||||
resBuf[i * n + j] = -2 * vBuf[i] * vBuf[j] + (i == j ? T(1) : T(0));
|
resBuf[i * n + j] = -2 * vBuf[i] * vBuf[j] + (i == j ? T(1) : T(0));
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(interloop, 0, n, 1, 0, n, 1);
|
sd::Threads::parallel_for(interloop, 0, n, 1, 0, n, 1);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,7 +119,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(batching, 0, listOutQ.size(), 1);
|
sd::Threads::parallel_tad(batching, 0, listOutQ.size(), 1);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -197,7 +197,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, batchValue, 1, 0, numOfSamples, 1);
|
sd::Threads::parallel_for(func, 0, batchValue, 1, 0, numOfSamples, 1);
|
||||||
rng.rewindH(output.lengthOf()*numOfClassX);
|
rng.rewindH(output.lengthOf()*numOfClassX);
|
||||||
|
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -42,7 +42,7 @@ static void _range(const NDArray& start, const NDArray& delta, NDArray& outVecto
|
||||||
for (auto i = start; i < stop; i++)
|
for (auto i = start; i < stop; i++)
|
||||||
buff[i] = s + i * d;
|
buff[i] = s + i * d;
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, len);
|
sd::Threads::parallel_for(func, 0, len);
|
||||||
}
|
}
|
||||||
|
|
||||||
void range(sd::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector) {
|
void range(sd::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector) {
|
||||||
|
|
|
@ -59,7 +59,7 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
|
||||||
swap(inArr, e, idx);
|
swap(inArr, e, idx);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2);
|
sd::Threads::parallel_for(func, 0, numOfElemsToReverse / 2);
|
||||||
}
|
}
|
||||||
else if (inEWS > 1) {
|
else if (inEWS > 1) {
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
@ -70,7 +70,7 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2);
|
sd::Threads::parallel_for(func, 0, numOfElemsToReverse / 2);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
|
@ -82,7 +82,7 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2);
|
sd::Threads::parallel_for(func, 0, numOfElemsToReverse / 2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
@ -96,14 +96,14 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
|
||||||
for (Nd4jLong e = start; e < stop; e++)
|
for (Nd4jLong e = start; e < stop; e++)
|
||||||
outArr[sLength - e] = inArr[e];
|
outArr[sLength - e] = inArr[e];
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, numOfElemsToReverse);
|
sd::Threads::parallel_for(func, 0, numOfElemsToReverse);
|
||||||
|
|
||||||
if(inLength != numOfElemsToReverse) {
|
if(inLength != numOfElemsToReverse) {
|
||||||
auto f2 = PRAGMA_THREADS_FOR {
|
auto f2 = PRAGMA_THREADS_FOR {
|
||||||
for (auto e = start; e < stop; e++)
|
for (auto e = start; e < stop; e++)
|
||||||
outArr[e] = inArr[e];
|
outArr[e] = inArr[e];
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength);
|
sd::Threads::parallel_for(f2, numOfElemsToReverse, inLength);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if (inEWS >= 1 && outEWS >= 1 && inOrder == outOrder) {
|
else if (inEWS >= 1 && outEWS >= 1 && inOrder == outOrder) {
|
||||||
|
@ -112,14 +112,14 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
|
||||||
for (auto e = start; e < stop; e++)
|
for (auto e = start; e < stop; e++)
|
||||||
outArr[(sLength - e) * outEWS] = inArr[e * inEWS];
|
outArr[(sLength - e) * outEWS] = inArr[e * inEWS];
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, numOfElemsToReverse);
|
sd::Threads::parallel_for(func, 0, numOfElemsToReverse);
|
||||||
|
|
||||||
if(inLength != numOfElemsToReverse) {
|
if(inLength != numOfElemsToReverse) {
|
||||||
auto f2 = PRAGMA_THREADS_FOR {
|
auto f2 = PRAGMA_THREADS_FOR {
|
||||||
for (auto e = start; e < stop; e++)
|
for (auto e = start; e < stop; e++)
|
||||||
outArr[e * outEWS] = inArr[e * inEWS];
|
outArr[e * outEWS] = inArr[e * inEWS];
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength);
|
sd::Threads::parallel_for(f2, numOfElemsToReverse, inLength);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
@ -131,7 +131,7 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
|
||||||
outArr[outOffset] = inArr[inOffset];
|
outArr[outOffset] = inArr[inOffset];
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, numOfElemsToReverse);
|
sd::Threads::parallel_for(func, 0, numOfElemsToReverse);
|
||||||
|
|
||||||
if(inLength != numOfElemsToReverse) {
|
if(inLength != numOfElemsToReverse) {
|
||||||
|
|
||||||
|
@ -142,7 +142,7 @@ static void reverseArray(sd::LaunchContext * context, void *vinArr, Nd4jLong *in
|
||||||
outArr[outOffset] = inArr[inOffset];
|
outArr[outOffset] = inArr[inOffset];
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength);
|
sd::Threads::parallel_for(f2, numOfElemsToReverse, inLength);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,7 +69,7 @@ static void batchToSpace_(const NDArray& input, NDArray& output, const uint crop
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, cropBottom, iH - cropTop, 1, cropLeft, iW - cropRight, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, cropBottom, iH - cropTop, 1, cropLeft, iW - cropRight, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void batchToSpace_, (const NDArray& input, NDArray& output, const uint cropBottom, const uint cropTop, const uint cropLeft, const uint cropRight), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void batchToSpace_, (const NDArray& input, NDArray& output, const uint cropBottom, const uint cropTop, const uint cropLeft, const uint cropRight), LIBND4J_TYPES);
|
||||||
|
@ -128,7 +128,7 @@ static void batchToSpaceND_(const NDArray& input, const NDArray& crop, NDArray&
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zLen);
|
sd::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void batchToSpaceND_, (const NDArray& input, const NDArray& crop, NDArray& output, const uint numOfSpatialDims), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void batchToSpaceND_, (const NDArray& input, const NDArray& crop, NDArray& output, const uint numOfSpatialDims), LIBND4J_TYPES);
|
||||||
|
@ -234,7 +234,7 @@ static void spaceToBatch_(const NDArray& input, NDArray& output, const uint padB
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1);
|
sd::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void spaceToBatch_, (const NDArray& input, NDArray& output, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void spaceToBatch_, (const NDArray& input, NDArray& output, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight), LIBND4J_TYPES);
|
||||||
|
@ -327,7 +327,7 @@ static void spaceToBatchND_(const NDArray& input, const NDArray& padding, NDArra
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zLen);
|
sd::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void spaceToBatchND_, (const NDArray& input, const NDArray& padding, NDArray& output, const uint numOfSpatialDims), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void spaceToBatchND_, (const NDArray& input, const NDArray& padding, NDArray& output, const uint numOfSpatialDims), LIBND4J_TYPES);
|
||||||
|
|
|
@ -69,7 +69,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, total_count);
|
sd::Threads::parallel_for(func, 0, total_count);
|
||||||
} else {
|
} else {
|
||||||
const int total_count = batch_size * output_depth_by_output_area;
|
const int total_count = batch_size * output_depth_by_output_area;
|
||||||
|
|
||||||
|
@ -93,7 +93,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, total_count);
|
sd::Threads::parallel_for(func, 0, total_count);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -58,7 +58,7 @@ Nd4jLong checkIndices_(const NDArray& indices, const NDArray& output, const int
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, indices.lengthOf());
|
sd::Threads::parallel_for(func, 0, indices.lengthOf());
|
||||||
|
|
||||||
return numOfBadIndx;
|
return numOfBadIndx;
|
||||||
}
|
}
|
||||||
|
@ -87,7 +87,7 @@ void scatter(sd::LaunchContext *context, pairwise::Ops op, const NDArray& indic
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads());
|
sd::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads());
|
||||||
}
|
}
|
||||||
else { // outRank > 1
|
else { // outRank > 1
|
||||||
|
|
||||||
|
@ -107,7 +107,7 @@ void scatter(sd::LaunchContext *context, pairwise::Ops op, const NDArray& indic
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads());
|
sd::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -129,7 +129,7 @@ void scatterND(sd::LaunchContext *context, pairwise::Ops op, const NDArray& ind
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads());
|
sd::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads());
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
std::vector<int> dimsToExcludeInd = ShapeUtils::evalDimsToExclude(indRank, {indRank-1});
|
std::vector<int> dimsToExcludeInd = ShapeUtils::evalDimsToExclude(indRank, {indRank-1});
|
||||||
|
@ -154,7 +154,7 @@ void scatterND(sd::LaunchContext *context, pairwise::Ops op, const NDArray& ind
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, indLen / indLastDim, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads());
|
sd::Threads::parallel_tad(func, 0, indLen / indLastDim, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -176,7 +176,7 @@ void scatterForLoss(sd::LaunchContext *context, const NDArray& indices, NDArray
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, indicesLen);
|
sd::Threads::parallel_for(func, 0, indicesLen);
|
||||||
} else {
|
} else {
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto i = start; i < stop; i++) {
|
for (auto i = start; i < stop; i++) {
|
||||||
|
@ -186,7 +186,7 @@ void scatterForLoss(sd::LaunchContext *context, const NDArray& indices, NDArray
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, indicesLen);
|
sd::Threads::parallel_for(func, 0, indicesLen);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -173,7 +173,7 @@ namespace helpers {
|
||||||
meanV.p<T>(e, meanV.e<T>(e) + listOfTensors.at(i)->e<T>(e));
|
meanV.p<T>(e, meanV.e<T>(e) + listOfTensors.at(i)->e<T>(e));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, meanT->lengthOf());
|
sd::Threads::parallel_for(func, 0, meanT->lengthOf());
|
||||||
|
|
||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
|
@ -227,7 +227,7 @@ namespace helpers {
|
||||||
sumT->p(e, sumT->e<T>(e) + listOfTensors.at(i)->e<T>(e));
|
sumT->p(e, sumT->e<T>(e) + listOfTensors.at(i)->e<T>(e));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, sumT->lengthOf());
|
sd::Threads::parallel_for(func, 0, sumT->lengthOf());
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
idx = indices->e<int>(i);
|
idx = indices->e<int>(i);
|
||||||
|
@ -276,7 +276,7 @@ namespace helpers {
|
||||||
sumT->p(e, sumT->e<T>(e) * listOfTensors.at(i)->e<T>(e));
|
sumT->p(e, sumT->e<T>(e) * listOfTensors.at(i)->e<T>(e));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, sumT->lengthOf());
|
sd::Threads::parallel_for(func, 0, sumT->lengthOf());
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
idx = indices->e<int>(i);
|
idx = indices->e<int>(i);
|
||||||
|
@ -631,7 +631,7 @@ namespace helpers {
|
||||||
output->p(e, gradOut->e<T>(classNum));
|
output->p(e, gradOut->e<T>(classNum));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, loop_size);
|
sd::Threads::parallel_for(func, 0, loop_size);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
std::vector<int> restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
std::vector<int> restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
@ -658,7 +658,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, indices->lengthOf());
|
sd::Threads::parallel_tad(func, 0, indices->lengthOf());
|
||||||
}
|
}
|
||||||
|
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
|
@ -681,7 +681,7 @@ namespace helpers {
|
||||||
output->p(e, gradOut->e<double>(classNum));
|
output->p(e, gradOut->e<double>(classNum));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, input->lengthOf());
|
sd::Threads::parallel_for(func, 0, input->lengthOf());
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
@ -711,7 +711,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, indices->lengthOf());
|
sd::Threads::parallel_tad(func, 0, indices->lengthOf());
|
||||||
}
|
}
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
}
|
}
|
||||||
|
@ -758,7 +758,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
//};
|
//};
|
||||||
|
|
||||||
//samediff::Threads::parallel_for(func, 0, indices->lengthOf());
|
//sd::Threads::parallel_for(func, 0, indices->lengthOf());
|
||||||
}
|
}
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
}
|
}
|
||||||
|
@ -791,7 +791,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
//};
|
//};
|
||||||
|
|
||||||
//samediff::Threads::parallel_for(func, 0, indices->lengthOf());
|
//sd::Threads::parallel_for(func, 0, indices->lengthOf());
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -828,7 +828,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
//};
|
//};
|
||||||
|
|
||||||
//samediff::Threads::parallel_for(func, 0, indices->lengthOf());
|
//sd::Threads::parallel_for(func, 0, indices->lengthOf());
|
||||||
}
|
}
|
||||||
|
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
|
@ -894,7 +894,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, input->lengthOf());
|
sd::Threads::parallel_for(func, 0, input->lengthOf());
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
@ -918,7 +918,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
//};
|
//};
|
||||||
|
|
||||||
//samediff::Threads::parallel_for(func, 0, indices->lengthOf());
|
//sd::Threads::parallel_for(func, 0, indices->lengthOf());
|
||||||
}
|
}
|
||||||
|
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
|
@ -993,7 +993,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
//};
|
//};
|
||||||
|
|
||||||
//samediff::Threads::parallel_for(func, 0, indices->lengthOf());
|
//sd::Threads::parallel_for(func, 0, indices->lengthOf());
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -1010,7 +1010,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, indices->lengthOf());
|
sd::Threads::parallel_for(func, 0, indices->lengthOf());
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
@ -1032,7 +1032,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
//};
|
//};
|
||||||
|
|
||||||
//samediff::Threads::parallel_for(func, 0, indices->lengthOf());
|
//sd::Threads::parallel_for(func, 0, indices->lengthOf());
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -1059,7 +1059,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
//};
|
//};
|
||||||
|
|
||||||
//samediff::Threads::parallel_for(func, 0, indices->lengthOf());
|
//sd::Threads::parallel_for(func, 0, indices->lengthOf());
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
@ -1081,7 +1081,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
//};
|
//};
|
||||||
|
|
||||||
//samediff::Threads::parallel_for(func, 0, indices->lengthOf());
|
//sd::Threads::parallel_for(func, 0, indices->lengthOf());
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,7 +34,7 @@ namespace helpers {
|
||||||
output->t<B>(k * maxIndex + i) = B(true); //, T(1.0f));
|
output->t<B>(k * maxIndex + i) = B(true); //, T(1.0f));
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, maxIndex, 1, 0, input->lengthOf(), 1);
|
sd::Threads::parallel_for(func, 0, maxIndex, 1, 0, input->lengthOf(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
void sequenceMask(sd::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {
|
void sequenceMask(sd::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {
|
||||||
|
|
|
@ -425,7 +425,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numTargets, 1, numThreads);
|
sd::Threads::parallel_tad(func, 0, numTargets, 1, numThreads);
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template void skipgramBatchExec_, (NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool preciseMode, const int numThreads), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void skipgramBatchExec_, (NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool preciseMode, const int numThreads), FLOAT_TYPES);
|
||||||
|
|
||||||
|
@ -577,7 +577,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numTargets, 1, numThreads);
|
sd::Threads::parallel_tad(func, 0, numTargets, 1, numThreads);
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template void cbowBatchExec_, (NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &context, NDArray &lockedWords, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, NDArray &nLabels, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool trainWords, const int numThreads), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void cbowBatchExec_, (NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &context, NDArray &lockedWords, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, NDArray &nLabels, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool trainWords, const int numThreads), FLOAT_TYPES);
|
||||||
|
|
||||||
|
|
|
@ -136,7 +136,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func,0, numOfSubArrs);
|
sd::Threads::parallel_tad(func,0, numOfSubArrs);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@ -168,7 +168,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func,0, numOfSubArrs);
|
sd::Threads::parallel_tad(func,0, numOfSubArrs);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -228,7 +228,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
|
sd::Threads::parallel_tad(func, 0, numOfSubArrs);
|
||||||
|
|
||||||
delete []offsets;
|
delete []offsets;
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,7 +48,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1);
|
sd::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// --------------------------------------------------------------------------------------------------------------------------------------- //
|
// --------------------------------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
|
@ -115,7 +115,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, input.lengthOf());
|
sd::Threads::parallel_for(func, 0, input.lengthOf());
|
||||||
}
|
}
|
||||||
|
|
||||||
void split(sd::LaunchContext* context, const NDArray& input, std::vector<NDArray*>& outArrs, const int axis) {
|
void split(sd::LaunchContext* context, const NDArray& input, std::vector<NDArray*>& outArrs, const int axis) {
|
||||||
|
|
|
@ -184,7 +184,7 @@ static void sruBI_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, ncols);
|
sd::Threads::parallel_tad(func, 0, ncols);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -303,7 +303,7 @@ static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArr
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, ncols);
|
sd::Threads::parallel_tad(func, 0, ncols);
|
||||||
|
|
||||||
// gradB
|
// gradB
|
||||||
gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0}); // [4*K]
|
gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0}); // [4*K]
|
||||||
|
|
|
@ -43,7 +43,7 @@ static void stack_(const std::vector<const NDArray*>& inArrs, NDArray& output, c
|
||||||
output.p<T>(i, inArrs[i]->t<T>(0));
|
output.p<T>(i, inArrs[i]->t<T>(0));
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, numOfSubArrs);
|
sd::Threads::parallel_for(func, 0, numOfSubArrs);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
|
@ -63,7 +63,7 @@ static void stack_(const std::vector<const NDArray*>& inArrs, NDArray& output, c
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
|
sd::Threads::parallel_tad(func, 0, numOfSubArrs);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -88,7 +88,7 @@ static void unstack_(const NDArray& input, const std::vector<NDArray*>& outArrs,
|
||||||
outArrs[i]->p<T>(0, input.t<T>(i));
|
outArrs[i]->p<T>(0, input.t<T>(i));
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, numOfSubArrs);
|
sd::Threads::parallel_for(func, 0, numOfSubArrs);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
|
@ -107,7 +107,7 @@ static void unstack_(const NDArray& input, const std::vector<NDArray*>& outArrs,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
|
sd::Threads::parallel_tad(func, 0, numOfSubArrs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -163,7 +163,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, target->lengthOf());
|
sd::Threads::parallel_tad(func, 0, target->lengthOf());
|
||||||
}
|
}
|
||||||
return status;
|
return status;
|
||||||
|
|
||||||
|
|
|
@ -48,7 +48,7 @@ static void triuBP_(sd::LaunchContext * context, const NDArray& input, const NDA
|
||||||
dOdI.t<T>(i) = static_cast<T>(1.f);
|
dOdI.t<T>(i) = static_cast<T>(1.f);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, dLen);
|
sd::Threads::parallel_for(func, 0, dLen);
|
||||||
|
|
||||||
// FIXME: !!!
|
// FIXME: !!!
|
||||||
gradI.assign(dOdI * gradO); // chain rule: dLoss/dI = dO/dI * dLoss/dO
|
gradI.assign(dOdI * gradO); // chain rule: dLoss/dI = dO/dI * dLoss/dO
|
||||||
|
@ -68,7 +68,7 @@ static void trace_(const NDArray& input, NDArray& output) {
|
||||||
for (auto i = start; i < stop; i++)
|
for (auto i = start; i < stop; i++)
|
||||||
output.p(i, setOfSubArrs.at(i)->getTrace());
|
output.p(i, setOfSubArrs.at(i)->getTrace());
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, setOfSubArrs.size());
|
sd::Threads::parallel_for(func, 0, setOfSubArrs.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
void trace(sd::LaunchContext * context, const NDArray& input, NDArray& output) {
|
void trace(sd::LaunchContext * context, const NDArray& input, NDArray& output) {
|
||||||
|
@ -211,7 +211,7 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zLen);
|
sd::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
}
|
||||||
else { // REFLECT and SYMMETRIC cases
|
else { // REFLECT and SYMMETRIC cases
|
||||||
|
|
||||||
|
@ -237,7 +237,7 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zLen);
|
sd::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -606,7 +606,7 @@ static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zLen);
|
sd::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -654,7 +654,7 @@ static void gather_(NDArray* input, const NDArray* indices, NDArray* output, con
|
||||||
output->p(e, input->e<T>(indices->e<Nd4jLong>(e)));
|
output->p(e, input->e<T>(indices->e<Nd4jLong>(e)));
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, indices->lengthOf());
|
sd::Threads::parallel_for(func, 0, indices->lengthOf());
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
|
@ -670,7 +670,7 @@ static void gather_(NDArray* input, const NDArray* indices, NDArray* output, con
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
|
sd::Threads::parallel_tad(func, 0, numOfSubArrs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
@ -694,7 +694,7 @@ static void gather_(NDArray* input, const NDArray* indices, NDArray* output, con
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
|
sd::Threads::parallel_tad(func, 0, numOfSubArrs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -714,7 +714,7 @@ void eye(sd::LaunchContext * context, NDArray& output) {
|
||||||
arrs.at(i)->setIdentity();
|
arrs.at(i)->setIdentity();
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, arrs.size());
|
sd::Threads::parallel_tad(func, 0, arrs.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -772,7 +772,7 @@ void scatterUpdate(sd::LaunchContext * context, NDArray& input, NDArray& updates
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, indices.size());
|
sd::Threads::parallel_tad(func, 0, indices.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -792,7 +792,7 @@ void scatterSimple(sd::LaunchContext * context, const int opId, NDArray& input,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, len);
|
sd::Threads::parallel_for(func, 0, len);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -824,7 +824,7 @@ static void mergeMaxIndex_(const std::vector<NDArray*>& inArrs, NDArray& output)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, x->lengthOf());
|
sd::Threads::parallel_for(func, 0, x->lengthOf());
|
||||||
}
|
}
|
||||||
|
|
||||||
void mergeMaxIndex(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
void mergeMaxIndex(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
||||||
|
@ -850,7 +850,7 @@ static void mergeMax_(const std::vector<NDArray*>& inArrs, NDArray& output) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, x->lengthOf());
|
sd::Threads::parallel_for(func, 0, x->lengthOf());
|
||||||
}
|
}
|
||||||
|
|
||||||
void mergeMax(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
void mergeMax(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
||||||
|
@ -875,7 +875,7 @@ static void mergeAvg_(const std::vector<NDArray*>& inArrs, NDArray& output) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, x->lengthOf());
|
sd::Threads::parallel_for(func, 0, x->lengthOf());
|
||||||
}
|
}
|
||||||
|
|
||||||
void mergeAvg(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
void mergeAvg(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
||||||
|
@ -900,7 +900,7 @@ static void mergeAdd_(const std::vector<NDArray*>& inArrs, NDArray& output) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, x->lengthOf());
|
sd::Threads::parallel_for(func, 0, x->lengthOf());
|
||||||
}
|
}
|
||||||
void mergeAdd(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
void mergeAdd(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
||||||
BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (inArrs, output), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (inArrs, output), LIBND4J_TYPES);
|
||||||
|
@ -934,7 +934,7 @@ static void clipByNorm_(NDArray& input, NDArray& output, const std::vector<int>&
|
||||||
*listOfInSubArrs.at(i) *= normClip / iNormActual;
|
*listOfInSubArrs.at(i) *= normClip / iNormActual;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size());
|
sd::Threads::parallel_tad(func, 0, listOfInSubArrs.size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
@ -963,7 +963,7 @@ static void clipByNorm_(NDArray& input, NDArray& output, const std::vector<int>&
|
||||||
*outputSubArr *= clipNorm / iNormActual;
|
*outputSubArr *= clipNorm / iNormActual;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size());
|
sd::Threads::parallel_tad(func, 0, listOfInSubArrs.size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1079,7 +1079,7 @@ static void clipByNormBP_(const NDArray& input, const NDArray& gradO, NDArray& g
|
||||||
gradISubArr->assign(gradOSubArr);
|
gradISubArr->assign(gradOSubArr);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_tad(func, 0, gradISubArrs.size());
|
sd::Threads::parallel_tad(func, 0, gradISubArrs.size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1215,7 +1215,7 @@ static void mirrorPad_(const NDArray& input, const NDArray& paddings, NDArray& o
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, outLen);
|
sd::Threads::parallel_for(func, 0, outLen);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -99,7 +99,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(batchLoop, 0, leftPart.size(), 1);
|
sd::Threads::parallel_tad(batchLoop, 0, leftPart.size(), 1);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
||||||
|
@ -128,7 +128,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1);
|
sd::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
int triangularSolveFunctor(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output) {
|
int triangularSolveFunctor(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output) {
|
||||||
|
|
|
@ -68,7 +68,7 @@ static void zeta_(sd::LaunchContext * context, const NDArray& x, const NDArray&
|
||||||
z.p(i, zetaScalar<T>(x.e<T>(i), q.e<T>(i)));
|
z.p(i, zetaScalar<T>(x.e<T>(i), q.e<T>(i)));
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, xLen);
|
sd::Threads::parallel_for(func, 0, xLen);
|
||||||
}
|
}
|
||||||
|
|
||||||
void zeta(sd::LaunchContext * context, const NDArray& x, const NDArray& q, NDArray& z) {
|
void zeta(sd::LaunchContext * context, const NDArray& x, const NDArray& q, NDArray& z) {
|
||||||
|
|
|
@ -77,7 +77,7 @@ void FORCEINLINE cross(sd::LaunchContext * context, NDArray *a, NDArray *b, NDAr
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, tads);
|
sd::Threads::parallel_tad(func, 0, tads);
|
||||||
}
|
}
|
||||||
|
|
||||||
void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext * context, NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output);
|
void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext * context, NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output);
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue