- 3d loops parallelism fix (#135)

- additional check for maxMasterThreads <= maxThreads

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-12-19 16:50:08 +03:00 committed by GitHub
parent 3d8f6d50a1
commit 8b877a8ddf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 57 additions and 2 deletions

View File

@ -61,6 +61,7 @@ namespace nd4j {
std::string omp(omp_threads);
int val = std::stoi(omp);
_maxThreads.store(val);
_maxMasterThreads.store(val);
} catch (std::invalid_argument &e) {
// just do nothing
} catch (std::out_of_range &e) {
@ -100,6 +101,11 @@ namespace nd4j {
}
}
if (_maxMasterThreads.load() > _maxThreads.load()) {
nd4j_printf("Warning! MAX_MASTER_THREADS > MAX_THREADS, tuning them down to match each other\n","");
_maxMasterThreads.store(_maxThreads.load());
}
/**
* If this env var is defined - we'll disallow use of platform-specific helpers (mkldnn, cudnn, etc)
*/

View File

@ -492,7 +492,7 @@ namespace samediff {
auto itersY = delta_y / incY;
auto itersZ = delta_z / incZ;
numThreads = 1; //ThreadsHelper::numberOfThreads3d(numThreads, itersX, itersY, itersZ);
numThreads = ThreadsHelper::numberOfThreads3d(numThreads, itersX, itersY, itersZ);
if (numThreads == 1) {
// loop is too small - executing function as is
function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);

View File

@ -59,6 +59,41 @@ public:
fflush(stdout);
}
};
/*
TEST_F(PlaygroundTests, test_s_1) {
auto x = NDArrayFactory::create<float>('c', {32,112,112,16});
auto y = NDArrayFactory::create<float>('c', {16});
auto z = x.ulike();
Context ctx(1);
ctx.setInputArray(0, &x);
ctx.setInputArray(1, &y);
ctx.setOutputArray(0, &z);
std::vector<Nd4jLong> values;
nd4j::ops::biasadd op;
op.execute(&ctx);
for (int e = 0; e < 1000; e++) {
auto timeStart = std::chrono::system_clock::now();
op.execute(&ctx);
auto timeEnd = std::chrono::system_clock::now();
auto outerTime = std::chrono::duration_cast<std::chrono::microseconds> (timeEnd - timeStart).count();
values.emplace_back(outerTime);
}
std::sort(values.begin(), values.end());
nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
}
*/
/*
TEST_F(PlaygroundTests, test_s_1) {
auto t = ::runLightBenchmarkSuit(true);

View File

@ -32,7 +32,9 @@ using namespace nd4j::graph;
class ThreadsTests : public testing::Test {
public:
ThreadsTests() {
nd4j_printf("\n","");
}
};
TEST_F(ThreadsTests, th_test_1) {
@ -84,6 +86,18 @@ TEST_F(ThreadsTests, th_test_3) {
ASSERT_EQ(1, ThreadsHelper::pickLoop3d(6, 8, 3, 64));
}
TEST_F(ThreadsTests, th_test_5) {
ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 32, 112, 112));
ASSERT_EQ(1, ThreadsHelper::pickLoop3d(6, 32, 112, 112));
for (auto e = 0; e < 6; e++) {
auto span = Span3::build(1, e, 6, 0, 32, 1, 0, 112, 1, 0, 112, 1);
nd4j_printf("Span start: %lld; stop: %lld\n", span.startX(), span.stopX());
}
}
TEST_F(ThreadsTests, th_test_4) {
// typical conv cases
ASSERT_EQ(2, ThreadsHelper::numberOfThreads2d(2, 32, 3));