cavis/libnd4j/include/graph/impl/SessionLocalStorage.cpp

124 lines
3.4 KiB
C++

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <cstdint>
#include <cstring>
#include <mutex>
#include <thread>
#include <VariableSpace.h>
#include <Stash.h>
#include <graph/SessionLocalStorage.h>
namespace nd4j {
namespace graph {
/**
 * Builds a session storage around the given VariableSpace and Stash.
 *
 * @param variableSpace pointer kept in _variableSpace; startSession() clones it
 *                      for every new session
 * @param stash         pointer kept in _stash; not used by the code visible here
 */
SessionLocalStorage::SessionLocalStorage(VariableSpace* variableSpace, Stash* stash) {
    this->_stash = stash;
    this->_variableSpace = variableSpace;

    // session ids are handed out starting at 1, since key 0 holds original VariableSpace
    this->_sessionCounter.store(1);
}
/**
 * Looks up the VariableSpace registered for a session.
 *
 * @param sessionId session id, as returned by startSession()
 * @return the VariableSpace stored under that id; it is still tracked by this
 *         storage, which deletes it in endSession() or in the destructor
 * @throws whatever _threadVariableSpace.at() throws for an unknown id
 *         (std::out_of_range for the standard map types)
 */
VariableSpace* SessionLocalStorage::localVariableSpace(Nd4jLong sessionId) {
    // at() throws for an unknown session id. The guard releases the mutex on
    // that path as well; a manual unlock() placed after the lookup would be
    // skipped and leave the mutex locked.
    std::lock_guard<decltype(_mutex)> guard(_mutex);
    return _threadVariableSpace.at(sessionId);
}
/**
 * Convenience overload: resolves the session id of the calling thread via
 * getSessionId() and returns the VariableSpace registered for it.
 *
 * @return result of localVariableSpace(Nd4jLong) for the current session id
 */
VariableSpace* SessionLocalStorage::localVariableSpace() {
    const auto currentSession = getSessionId();
    return localVariableSpace(currentSession);
}
/**
 * Deletes every VariableSpace that is still registered in _threadVariableSpace.
 */
SessionLocalStorage::~SessionLocalStorage() {
    for (auto entry = _threadVariableSpace.begin(); entry != _threadVariableSpace.end(); ++entry) {
        delete entry->second;
    }
}
/**
 * Produces an integer key that identifies the calling thread.
 *
 * The key is the leading bytes (at most 8) of the object returned by
 * std::this_thread::get_id(), copied into a zero-initialised 64-bit integer.
 *
 * @return key used to index _threadSession
 */
Nd4jLong SessionLocalStorage::getThreadId() {
    // TODO: platform-native thread ids (syscall on Apple/Linux, win32 api on
    // Windows) were planned here but never implemented.
    const auto id = std::this_thread::get_id();

    // The width of std::thread::id is implementation-defined. Reading it through
    // a uint64_t* breaks aliasing rules and over-reads when the id is narrower
    // than 8 bytes, so copy only the bytes that exist. Where the id is 8 bytes
    // wide this yields the same value as the pointer cast did.
    uint64_t raw = 0;
    constexpr std::size_t width = sizeof(id) < sizeof(raw) ? sizeof(id) : sizeof(raw);
    std::memcpy(&raw, &id, width);

    return static_cast<Nd4jLong>(raw);
}
/**
 * Reports how many thread-to-session entries are currently held in
 * _threadSession.
 *
 * @return entry count, narrowed to int
 */
int SessionLocalStorage::numberOfSessions() {
    _mutex.lock();
    const auto tracked = _threadSession.size();
    _mutex.unlock();

    return static_cast<int>(tracked);
}
/**
 * Ends the given session: removes its VariableSpace from _threadVariableSpace
 * and deletes it. Unknown session ids are ignored.
 *
 * @param sessionId id of the session to tear down
 */
void SessionLocalStorage::endSession(Nd4jLong sessionId) {
    VariableSpace* released = nullptr;
    {
        std::lock_guard<decltype(_mutex)> guard(_mutex);

        // find() instead of operator[]: an unknown id must not insert a
        // placeholder entry just so that it can be erased again
        auto entry = _threadVariableSpace.find(sessionId);
        if (entry == _threadVariableSpace.end())
            return;

        released = entry->second;
        _threadVariableSpace.erase(entry);
    }

    // the entry is already unreachable through the map, so the destructor
    // does not need to run inside the critical section
    delete released;
}
void SessionLocalStorage::endSession() {
auto tid = getThreadId();
_mutex.lock();
auto ntid = _threadSession[tid];
_threadSession.erase(tid);
_mutex.unlock();
endSession(ntid);
}
/**
 * Returns the session id bound to the calling thread.
 *
 * @return the id stored in _threadSession for this thread, or 0 when the
 *         thread has no entry
 */
Nd4jLong SessionLocalStorage::getSessionId() {
    const auto tid = getThreadId();

    std::lock_guard<decltype(_mutex)> guard(_mutex);

    // Read-only lookup: operator[] would insert a tid -> 0 entry for a thread
    // without a session, which numberOfSessions() would then count.
    const auto entry = _threadSession.find(tid);
    if (entry == _threadSession.end())
        return static_cast<Nd4jLong>(0);

    return entry->second;
}
/**
 * Opens a new session for the calling thread.
 *
 * Takes the next id from _sessionCounter, binds the calling thread to it in
 * _threadSession and registers a clone of _variableSpace under that id. If the
 * thread was already bound to a session, the binding is overwritten; the
 * previous session's VariableSpace stays in _threadVariableSpace.
 *
 * @return id of the newly created session
 */
Nd4jLong SessionLocalStorage::startSession() {
    const auto tid = getThreadId();
    nd4j_debug("Adding ThreadId: %i;\n", static_cast<int>(tid));

    const Nd4jLong ntid = _sessionCounter++;

    // Scoped lock: clone() or a map insertion may throw, and the mutex must be
    // released on that path too.
    std::lock_guard<decltype(_mutex)> guard(_mutex);

    // Clone before touching either map, so a failed clone leaves both unchanged.
    // NOTE(review): assumes _variableSpace is non-null here - confirm with callers.
    auto cloned = _variableSpace->clone();

    _threadSession[tid] = ntid;
    _threadVariableSpace[ntid] = cloned;

    return ntid;
}
}
}