parent
ca96a13ed0
commit
3bb22a6ff8
|
@ -411,8 +411,24 @@ namespace sd {
|
|||
// }
|
||||
// else {
|
||||
if (indices.size()) {
|
||||
auto sub = (*x)(indices, true, true);
|
||||
z->assign(sub);
|
||||
Nd4jLong* subArrShapeInfo = nullptr;
|
||||
ALLOCATE(subArrShapeInfo, block.getWorkspace(), shape::shapeInfoLength(x->rankOf()), Nd4jLong);
|
||||
Nd4jLong offset;
|
||||
|
||||
shape::calcSubArrShapeInfoAndOffset(indices.data(), x->getShapeInfo(), subArrShapeInfo, offset, true, true);
|
||||
auto subArrShapeInfoPack = ConstantShapeHelper::getInstance()->bufferForShapeInfo(subArrShapeInfo);
|
||||
|
||||
NDArray::prepareSpecialUse({z}, {x});
|
||||
|
||||
NativeOpExecutioner::execTransformAny(block.launchContext(), sd::transform::Assign,
|
||||
x->bufferWithOffset(offset), reinterpret_cast<Nd4jLong *>(subArrShapeInfoPack.primary()),
|
||||
x->specialBufferWithOffset(offset), reinterpret_cast<Nd4jLong *>(subArrShapeInfoPack.special()),
|
||||
z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(),
|
||||
nullptr, nullptr, nullptr, true);
|
||||
|
||||
NDArray::registerSpecialUse({z}, {x});
|
||||
|
||||
RELEASE(subArrShapeInfo, block.getWorkspace());
|
||||
}
|
||||
else if (!z->isEmpty()){
|
||||
z->assign(x->e(0));
|
||||
|
|
Loading…
Reference in New Issue