cavis/libnd4j/include/execution/impl/CallableInterface.cpp

213 lines
6.4 KiB
C++

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <execution/CallableInterface.h>
#include <helpers/logger.h>
namespace samediff {
    namespace {
        /**
         * Publishes a freshly filled task: raises the "filled" flag while holding the
         * starter mutex (the same mutex waitForTask() waits with, so the waiter cannot
         * miss the update between checking its predicate and blocking), then wakes one
         * thread waiting on the starter condition variable.
         */
        template <typename Flag, typename CondVar>
        void publishTask(std::mutex &starterMutex, Flag &filled, CondVar &starter) {
            {
                std::lock_guard<std::mutex> lock(starterMutex);
                filled = true;
            }
            starter.notify_one();
        }
    }

    /**
     * Creates an idle callable: flagged as available, with no task filled in
     * and no task marked as finished.
     */
    CallableInterface::CallableInterface() {
        // initial state is available
        _available = true;
        _filled = false;
        _finished = false;
    }

    /**
     * @return the current value of the availability flag
     *         (set by markAvailable(), cleared by markUnavailable()).
     */
    bool CallableInterface::available() {
        return _available.load();
    }

    /** Clears the availability flag. */
    void CallableInterface::markUnavailable() {
        _available = false;
    }

    /** Sets the availability flag. */
    void CallableInterface::markAvailable() {
        _available = true;
    }

    /**
     * Fills this callable with a plain task; execute() will invoke it as
     * func(threadID, numThreads).
     *
     * @param threadID   thread id passed to the task
     * @param numThreads thread count passed to the task
     * @param func       task to run (moved in)
     */
    void CallableInterface::fill(int threadID, int numThreads, FUNC_DO func) {
        _function_do = std::move(func);
        _branch = 0;
        _num_threads = numThreads;
        _thread_id = threadID;
        // reset before publishing so waitForCompletion() waits for *this* task
        _finished = false;
        publishTask(_ms, _filled, _starter);
    }

    /**
     * Fills this callable with a 1D loop task; execute() will invoke it as
     * func(threadID, startX, stopX, incX).
     *
     * @param threadID   thread id passed to the task
     * @param numThreads stored, but not passed to a 1D task
     * @param func       task to run (moved in)
     */
    void CallableInterface::fill(int threadID, int numThreads, FUNC_1D func, int64_t startX, int64_t stopX, int64_t incX) {
        _function_1d = std::move(func);
        _arguments[0] = startX;
        _arguments[1] = stopX;
        _arguments[2] = incX;
        _branch = 1;
        _num_threads = numThreads;
        _thread_id = threadID;
        _finished = false;
        publishTask(_ms, _filled, _starter);
    }

    /**
     * Fills this callable with a 2D loop task; execute() will invoke it as
     * func(threadID, startX, stopX, incX, start_y, stop_y, inc_y).
     *
     * @param threadID   thread id passed to the task
     * @param numThreads stored, but not passed to a 2D task
     * @param func       task to run (moved in)
     */
    void CallableInterface::fill(int threadID, int numThreads, FUNC_2D func, int64_t startX, int64_t stopX, int64_t incX, int64_t start_y, int64_t stop_y, int64_t inc_y) {
        _function_2d = std::move(func);
        _arguments[0] = startX;
        _arguments[1] = stopX;
        _arguments[2] = incX;
        _arguments[3] = start_y;
        _arguments[4] = stop_y;
        _arguments[5] = inc_y;
        _branch = 2;
        _num_threads = numThreads;
        _thread_id = threadID;
        _finished = false;
        publishTask(_ms, _filled, _starter);
    }

    /**
     * Fills this callable with a 3D loop task; execute() will invoke it as
     * func(threadID, startX, stopX, incX, start_y, stop_y, inc_y, start_z, stop_z, inc_z).
     *
     * @param threadID   thread id passed to the task
     * @param numThreads stored, but not passed to a 3D task
     * @param func       task to run (moved in)
     */
    void CallableInterface::fill(int threadID, int numThreads, FUNC_3D func, int64_t startX, int64_t stopX, int64_t incX, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z) {
        _function_3d = std::move(func);
        _arguments[0] = startX;
        _arguments[1] = stopX;
        _arguments[2] = incX;
        _arguments[3] = start_y;
        _arguments[4] = stop_y;
        _arguments[5] = inc_y;
        _arguments[6] = start_z;
        _arguments[7] = stop_z;
        _arguments[8] = inc_z;
        _branch = 3;
        _num_threads = numThreads;
        _thread_id = threadID;
        _finished = false;
        publishTask(_ms, _filled, _starter);
    }

    /**
     * Fills this callable with an int64 reduction task; execute() will store
     * func(threadID, startX, stopX, incX) into lptr[0].
     *
     * @param threadID   thread id passed to the task
     * @param numThreads stored, but not passed to the task
     * @param lptr       destination of the result; must stay valid until the task has executed
     * @param func       task to run (moved in)
     */
    void CallableInterface::fill(int threadID, int numThreads, int64_t *lptr, FUNC_RL func, int64_t startX, int64_t stopX, int64_t incX) {
        _function_rl = std::move(func);
        _arguments[0] = startX;
        _arguments[1] = stopX;
        _arguments[2] = incX;
        _lptr = lptr;
        _branch = 4;
        _num_threads = numThreads;
        _thread_id = threadID;
        _finished = false;
        publishTask(_ms, _filled, _starter);
    }

    /**
     * Fills this callable with a double reduction task; execute() will store
     * func(threadID, startX, stopX, incX) into dptr[0].
     *
     * @param threadID   thread id passed to the task
     * @param numThreads stored, but not passed to the task
     * @param dptr       destination of the result; must stay valid until the task has executed
     * @param func       task to run (moved in)
     */
    void CallableInterface::fill(int threadID, int numThreads, double *dptr, FUNC_RD func, int64_t startX, int64_t stopX, int64_t incX) {
        _function_rd = std::move(func);
        _arguments[0] = startX;
        _arguments[1] = stopX;
        _arguments[2] = incX;
        _dptr = dptr;
        _branch = 5;
        _num_threads = numThreads;
        _thread_id = threadID;
        _finished = false;
        publishTask(_ms, _filled, _starter);
    }

    /** Blocks the calling thread until a task has been filled in (the filled flag is true). */
    void CallableInterface::waitForTask() {
        std::unique_lock<std::mutex> lock(_ms);
        _starter.wait(lock, [&]{ return _filled.load(); });
    }

    /** Blocks the calling thread until the current task has been marked finished. */
    void CallableInterface::waitForCompletion() {
        std::unique_lock<std::mutex> lock(_mf);
        _finisher.wait(lock, [&] { return _finished.load(); });
    }

    /**
     * Marks the current task as finished and wakes one thread blocked in
     * waitForCompletion(). The flag is set under the finisher mutex so the
     * waiter cannot miss the notification.
     */
    void CallableInterface::finish() {
        {
            std::lock_guard<std::mutex> lock(_mf);
            _finished.store(true);
        }
        _finisher.notify_one();
    }

    /**
     * Runs the currently filled task on the calling thread, then signals completion.
     *
     * The filled flag is cleared first (task consumed). Then the function selected by
     * the branch id recorded in fill() is invoked. Reduction branches (4 and 5) write
     * their result to the pointer supplied to fill(). Completion is signalled through
     * finish() on every exit path: if the task throws, waiters in waitForCompletion()
     * are still released instead of blocking forever, and the exception keeps propagating.
     */
    void CallableInterface::execute() {
        // RAII: calls finish() when execute() returns or unwinds
        struct CompletionGuard {
            CallableInterface *owner;
            ~CompletionGuard() { owner->finish(); }
        } completion{this};

        // mark it as consumed
        _filled = false;

        switch (_branch) {
            case 0:
                _function_do(_thread_id, _num_threads);
                break;
            case 1:
                _function_1d(_thread_id, _arguments[0], _arguments[1], _arguments[2]);
                break;
            case 2:
                _function_2d(_thread_id, _arguments[0], _arguments[1], _arguments[2], _arguments[3], _arguments[4], _arguments[5]);
                break;
            case 3:
                _function_3d(_thread_id, _arguments[0], _arguments[1], _arguments[2], _arguments[3], _arguments[4], _arguments[5], _arguments[6], _arguments[7], _arguments[8]);
                break;
            case 4:
                _lptr[0] = _function_rl(_thread_id, _arguments[0], _arguments[1], _arguments[2]);
                break;
            case 5:
                _dptr[0] = _function_rd(_thread_id, _arguments[0], _arguments[1], _arguments[2]);
                break;
            default:
                // unknown branch id: nothing to run; completion is still signalled by the guard
                break;
        }
    }
}