temporary stack fix

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-03 15:29:21 +03:00 committed by AlexDBlack
parent b75bac750d
commit b597fb942b
1 changed files with 9 additions and 2 deletions

View File

@ -64,7 +64,11 @@ namespace helpers {
const int threadsPerBlock = MAX_NUM_THREADS / 2; const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = scalarCase ? (outArr->lengthOf() + threadsPerBlock - 1) / threadsPerBlock : inArrs.size(); const int blocksPerGrid = scalarCase ? (outArr->lengthOf() + threadsPerBlock - 1) / threadsPerBlock : inArrs.size();
NDArray::prepareSpecialUse({outArr}, inArrs); NDArray::prepareSpecialUse({outArr}, {});
// FIXME: !!!
for (auto v:inArrs)
NDArray::prepareSpecialUse({}, {v});
std::vector<void const*> inputList(inArrs.size()); std::vector<void const*> inputList(inArrs.size());
std::vector<Nd4jLong const*> inputShapeList(inArrs.size()); std::vector<Nd4jLong const*> inputShapeList(inArrs.size());
@ -88,8 +92,11 @@ namespace helpers {
} }
manager.synchronize(); manager.synchronize();
NDArray::registerSpecialUse({outArr}, inArrs); NDArray::registerSpecialUse({outArr}, {});
// FIXME: !!!
for (auto v:inArrs)
NDArray::registerSpecialUse({}, {v});
} }
void stack(nd4j::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim) { void stack(nd4j::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim) {