parent
ca96a13ed0
commit
3bb22a6ff8
|
@ -411,8 +411,24 @@ namespace sd {
|
||||||
// }
|
// }
|
||||||
// else {
|
// else {
|
||||||
if (indices.size()) {
|
if (indices.size()) {
|
||||||
auto sub = (*x)(indices, true, true);
|
Nd4jLong* subArrShapeInfo = nullptr;
|
||||||
z->assign(sub);
|
ALLOCATE(subArrShapeInfo, block.getWorkspace(), shape::shapeInfoLength(x->rankOf()), Nd4jLong);
|
||||||
|
Nd4jLong offset;
|
||||||
|
|
||||||
|
shape::calcSubArrShapeInfoAndOffset(indices.data(), x->getShapeInfo(), subArrShapeInfo, offset, true, true);
|
||||||
|
auto subArrShapeInfoPack = ConstantShapeHelper::getInstance()->bufferForShapeInfo(subArrShapeInfo);
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({z}, {x});
|
||||||
|
|
||||||
|
NativeOpExecutioner::execTransformAny(block.launchContext(), sd::transform::Assign,
|
||||||
|
x->bufferWithOffset(offset), reinterpret_cast<Nd4jLong *>(subArrShapeInfoPack.primary()),
|
||||||
|
x->specialBufferWithOffset(offset), reinterpret_cast<Nd4jLong *>(subArrShapeInfoPack.special()),
|
||||||
|
z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(),
|
||||||
|
nullptr, nullptr, nullptr, true);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({z}, {x});
|
||||||
|
|
||||||
|
RELEASE(subArrShapeInfo, block.getWorkspace());
|
||||||
}
|
}
|
||||||
else if (!z->isEmpty()){
|
else if (!z->isEmpty()){
|
||||||
z->assign(x->e(0));
|
z->assign(x->e(0));
|
||||||
|
|
Loading…
Reference in New Issue