124 lines
3.4 KiB
C++
124 lines
3.4 KiB
C++
|
/*******************************************************************************
|
||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||
|
*
|
||
|
* This program and the accompanying materials are made available under the
|
||
|
* terms of the Apache License, Version 2.0 which is available at
|
||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||
|
* License for the specific language governing permissions and limitations
|
||
|
* under the License.
|
||
|
*
|
||
|
* SPDX-License-Identifier: Apache-2.0
|
||
|
******************************************************************************/
|
||
|
|
||
|
//
|
||
|
// @author raver119@gmail.com
|
||
|
//
|
||
|
|
||
|
#include <cstring>
#include <mutex>

#include <VariableSpace.h>
#include <Stash.h>
#include <graph/SessionLocalStorage.h>
|
||
|
|
||
|
namespace nd4j {
|
||
|
namespace graph {
|
||
|
/**
 * Creates a storage bound to the given VariableSpace and Stash.
 *
 * Both pointers are stored as-is; this constructor neither copies nor
 * dereferences them.
 *
 * @param variableSpace VariableSpace later cloned by startSession()
 * @param stash         Stash pointer kept alongside the variable space
 */
SessionLocalStorage::SessionLocalStorage(VariableSpace* variableSpace, Stash* stash) {
    this->_variableSpace = variableSpace;
    this->_stash = stash;

    // session ids are handed out starting from 1, since key 0 holds original VariableSpace
    this->_sessionCounter.store(1);
}
|
||
|
|
||
|
/**
 * Returns the VariableSpace registered for the given session.
 *
 * The lookup is done while holding _mutex. The lock is scoped, so it is
 * released even when the lookup throws; with manual lock()/unlock() an
 * unknown session id would have left the mutex locked forever.
 *
 * @param sessionId id previously returned by startSession()
 * @return pointer stored for that session
 * @throws whatever _threadVariableSpace.at() throws for an unknown id
 *         (std::out_of_range for the standard associative containers)
 */
VariableSpace* SessionLocalStorage::localVariableSpace(Nd4jLong sessionId) {
    std::lock_guard<decltype(_mutex)> guard(_mutex);
    return _threadVariableSpace.at(sessionId);
}
|
||
|
|
||
|
/**
 * Returns the VariableSpace of the session that getSessionId() reports
 * for the calling thread.
 *
 * @return result of localVariableSpace(sessionId) for that session id
 */
VariableSpace* SessionLocalStorage::localVariableSpace() {
    const auto currentSession = getSessionId();
    return localVariableSpace(currentSession);
}
|
||
|
|
||
|
/**
 * Deletes every VariableSpace still registered in _threadVariableSpace.
 * The map entries themselves are left in place; only the pointees are freed.
 */
SessionLocalStorage::~SessionLocalStorage() {
    for (auto it = _threadVariableSpace.begin(); it != _threadVariableSpace.end(); ++it)
        delete it->second;
}
|
||
|
|
||
|
|
||
|
/**
 * Returns a numeric identifier for the calling thread.
 *
 * The value is built from the object representation of
 * std::this_thread::get_id(): at most 8 bytes are copied into a
 * zero-initialised 64-bit integer. This replaces a pointer cast that read
 * a full uint64_t out of the id object, which violates strict aliasing and
 * reads past the object on implementations where std::thread::id is
 * narrower than 8 bytes. Where the id is exactly 8 bytes the result is the
 * same as before.
 *
 * The same std facility is used on every platform; no OS-specific calls.
 *
 * @return thread identifier converted to Nd4jLong
 */
Nd4jLong SessionLocalStorage::getThreadId() {
    const std::thread::id id = std::this_thread::get_id();

    uint64_t raw = 0;

    // the size of std::thread::id is implementation-defined: never copy more than it holds
    const std::size_t bytes = sizeof(std::thread::id) < sizeof(raw) ? sizeof(std::thread::id) : sizeof(raw);
    std::memcpy(&raw, &id, bytes);

    return static_cast<Nd4jLong>(raw);
}
|
||
|
|
||
|
int SessionLocalStorage::numberOfSessions() {
|
||
|
_mutex.lock();
|
||
|
int size = (int) _threadSession.size();
|
||
|
_mutex.unlock();
|
||
|
return size;
|
||
|
}
|
||
|
|
||
|
/**
 * Ends the given session: removes its entry from _threadVariableSpace and
 * deletes the VariableSpace stored there.
 *
 * Unknown session ids are a no-op. The lookup uses find() rather than
 * operator[], so a missing id no longer causes a throw-away entry to be
 * inserted and immediately erased. _mutex is held through a scoped lock
 * for the whole operation and is released on every exit path.
 *
 * @param sessionId id previously returned by startSession()
 */
void SessionLocalStorage::endSession(Nd4jLong sessionId) {
    std::lock_guard<decltype(_mutex)> guard(_mutex);

    auto it = _threadVariableSpace.find(sessionId);
    if (it == _threadVariableSpace.end())
        return;

    // take the pointer out of the map before destroying it
    auto vs = it->second;
    _threadVariableSpace.erase(it);

    delete vs;
}
|
||
|
|
||
|
void SessionLocalStorage::endSession() {
|
||
|
auto tid = getThreadId();
|
||
|
|
||
|
_mutex.lock();
|
||
|
|
||
|
auto ntid = _threadSession[tid];
|
||
|
_threadSession.erase(tid);
|
||
|
|
||
|
_mutex.unlock();
|
||
|
|
||
|
endSession(ntid);
|
||
|
}
|
||
|
|
||
|
/**
 * Returns the session id registered for the calling thread.
 *
 * The lookup is read-only: find() is used instead of operator[], so
 * querying from a thread that never called startSession() no longer
 * inserts a "thread -> 0" entry into _threadSession (such entries were
 * counted by numberOfSessions()).
 *
 * @return the thread's session id, or 0 when the thread has no entry
 */
Nd4jLong SessionLocalStorage::getSessionId() {
    const auto tid = getThreadId();

    std::lock_guard<decltype(_mutex)> guard(_mutex);

    auto it = _threadSession.find(tid);
    if (it == _threadSession.end())
        return 0;

    return it->second;
}
|
||
|
|
||
|
/**
 * Starts a new session for the calling thread.
 *
 * A fresh id is taken from _sessionCounter, _variableSpace is cloned, and
 * both the thread -> session and session -> VariableSpace mappings are
 * recorded under _mutex.
 *
 * The clone is created before either map is touched and the lock is
 * scoped, so if clone() throws the mutex is released and no mapping is
 * left pointing at a session without a VariableSpace.
 *
 * @return id of the newly created session
 */
Nd4jLong SessionLocalStorage::startSession() {
    auto tid = getThreadId();

    nd4j_debug("Adding ThreadId: %i;\n", (int) tid);
    Nd4jLong ntid = _sessionCounter++;

    std::lock_guard<decltype(_mutex)> guard(_mutex);

    // clone first: a failure here must not leave a half-registered session behind
    auto sessionSpace = _variableSpace->clone();

    _threadVariableSpace[ntid] = sessionSpace;
    _threadSession[tid] = ntid;

    return ntid;
}
|
||
|
}
|
||
|
}
|
||
|
|