parent
b75bac750d
commit
b597fb942b
|
@ -64,7 +64,11 @@ namespace helpers {
|
||||||
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||||
const int blocksPerGrid = scalarCase ? (outArr->lengthOf() + threadsPerBlock - 1) / threadsPerBlock : inArrs.size();
|
const int blocksPerGrid = scalarCase ? (outArr->lengthOf() + threadsPerBlock - 1) / threadsPerBlock : inArrs.size();
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({outArr}, inArrs);
|
NDArray::prepareSpecialUse({outArr}, {});
|
||||||
|
|
||||||
|
// FIXME: !!!
|
||||||
|
for (auto v:inArrs)
|
||||||
|
NDArray::prepareSpecialUse({}, {v});
|
||||||
|
|
||||||
std::vector<void const*> inputList(inArrs.size());
|
std::vector<void const*> inputList(inArrs.size());
|
||||||
std::vector<Nd4jLong const*> inputShapeList(inArrs.size());
|
std::vector<Nd4jLong const*> inputShapeList(inArrs.size());
|
||||||
|
@ -88,8 +92,11 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
manager.synchronize();
|
manager.synchronize();
|
||||||
|
|
||||||
NDArray::registerSpecialUse({outArr}, inArrs);
|
NDArray::registerSpecialUse({outArr}, {});
|
||||||
|
|
||||||
|
// FIXME: !!!
|
||||||
|
for (auto v:inArrs)
|
||||||
|
NDArray::registerSpecialUse({}, {v});
|
||||||
}
|
}
|
||||||
|
|
||||||
void stack(nd4j::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim) {
|
void stack(nd4j::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim) {
|
||||||
|
|
Loading…
Reference in New Issue