parent
b75bac750d
commit
b597fb942b
|
@ -64,7 +64,11 @@ namespace helpers {
|
|||
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||
const int blocksPerGrid = scalarCase ? (outArr->lengthOf() + threadsPerBlock - 1) / threadsPerBlock : inArrs.size();
|
||||
|
||||
NDArray::prepareSpecialUse({outArr}, inArrs);
|
||||
NDArray::prepareSpecialUse({outArr}, {});
|
||||
|
||||
// FIXME: !!!
|
||||
for (auto v:inArrs)
|
||||
NDArray::prepareSpecialUse({}, {v});
|
||||
|
||||
std::vector<void const*> inputList(inArrs.size());
|
||||
std::vector<Nd4jLong const*> inputShapeList(inArrs.size());
|
||||
|
@ -88,8 +92,11 @@ namespace helpers {
|
|||
}
|
||||
manager.synchronize();
|
||||
|
||||
NDArray::registerSpecialUse({outArr}, inArrs);
|
||||
NDArray::registerSpecialUse({outArr}, {});
|
||||
|
||||
// FIXME: !!!
|
||||
for (auto v:inArrs)
|
||||
NDArray::registerSpecialUse({}, {v});
|
||||
}
|
||||
|
||||
void stack(nd4j::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim) {
|
||||
|
|
Loading…
Reference in New Issue