diff --git a/libnd4j/include/array/impl/NDArrayList.cpp b/libnd4j/include/array/impl/NDArrayList.cpp index deb7828d8..ecd4bcaca 100644 --- a/libnd4j/include/array/impl/NDArrayList.cpp +++ b/libnd4j/include/array/impl/NDArrayList.cpp @@ -23,6 +23,7 @@ #include #include #include +#include namespace sd { NDArrayList::NDArrayList(int height, bool expandable) { @@ -144,25 +145,38 @@ namespace sd { NDArray* NDArrayList::stack() { // FIXME: this is bad for perf, but ok as poc - sd::ops::stack op; - std::vector inputs; - std::vector targs; - std::vector iargs({0}); - std::vector bargs; + int numElements = _elements.load(); - + std::vector inputs(numElements); for (int e = 0; e < numElements; e++) { _chunks[e]->syncToDevice(); - inputs.emplace_back(_chunks[e]); + inputs[e] = _chunks[e]; } - iargs.push_back(_axis); + auto inShapeInfo = inputs[0]->getShapeInfo(); + int rank = shape::rank(inShapeInfo); + NDArray* array = nullptr; - auto result = op.evaluate(inputs); + if (shape::isEmpty(inShapeInfo)) { + switch (rank) { + case 0: { + if (numElements == 1) { + array = new NDArray(inputs[0]->ordering(), {0}, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext()); + } else { + array = new NDArray('c', {(Nd4jLong) numElements, 0}, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext() ) ; + } + } + } + } + else{ + std::vector outShape(inShapeInfo + 1, inShapeInfo + 1 + rank); + outShape.insert(outShape.begin(), (Nd4jLong) numElements); + array = new NDArray( shape::order(inShapeInfo), outShape, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext()); + } + + ops::helpers::stack(inputs[0]->getContext(), inputs, *array, 0); - auto array = new NDArray(result.at(0)->dup()); - - return array; + return array; } std::pair& NDArrayList::id() { diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp index 52377739b..fac209905 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp @@ -14,10 +14,10 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -// -// @author raver119@gmail.com -// modified by sgazeos@gmail.com with backprop implementation. -// + // + // @author raver119@gmail.com + // modified by sgazeos@gmail.com with backprop implementation. + // #include #if NOT_EXCLUDED(OP_floormod) @@ -31,7 +31,7 @@ namespace sd { auto y = INPUT_VARIABLE(1); auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); REQUIRE_TRUE(!y->isB(), 0, "FLOORMOD OP: you can't divide by bool array!"); auto tZ = BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, z); @@ -46,15 +46,15 @@ namespace sd { DECLARE_TYPES(floormod) { getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); } DECLARE_TYPES(floormod_bp) { getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ ALL_FLOATS }); } CUSTOM_OP_IMPL(floormod_bp, 3, 2, false, 0, 0) { @@ -66,11 +66,11 @@ namespace sd { auto gradY = OUTPUT_VARIABLE(1); gradX->assign(epsNext); - sd::ops::floormod op; - auto tmpResult(op.evaluate({x, y})); + NDArray temp(*epsNext); + BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, &temp); if (gradY->rankOf() == gradX->rankOf()) - epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult.at(0), *gradY); + epsNext->applyPairwiseTransform(pairwise::Multiply, temp, *gradY); else // epsNext is greater than gradY { std::vector dims(epsNext->rankOf() * 2); @@ -78,7 +78,7 @@ namespace sd { for (Nd4jLong d = 0; d < gap; d++) { dims[d * 2 + 1] = 1; } - auto tempIn((*tmpResult.at(0))(dims)); + auto tempIn((temp)(dims)); (*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY); } return Status::OK(); @@ -92,8 +92,8 @@ namespace sd { // eps always has shape of x // grad always has shape of y - Nd4jLong *shapeE; - Nd4jLong *shapeG; + Nd4jLong* shapeE; + Nd4jLong* shapeG; COPY_SHAPE(x, shapeE); COPY_SHAPE(y, shapeG); diff --git a/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp b/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp index f88710904..40e080a8f 100644 --- a/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp +++ b/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp @@ -14,9 +14,9 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -// -// @author raver119@gmail.com -// + // + // @author raver119@gmail.com + // #include #if NOT_EXCLUDED(OP_split_string) @@ -60,7 +60,7 @@ namespace sd { // filling output indices for (uint64_t f = 0; f < cnt; f++) { - for (auto v: icoords) + for (auto v : icoords) indices->p(ic++, v); // last index @@ -75,12 +75,12 @@ namespace sd { for (auto e = 0L; e < input->lengthOf(); e++) { auto split = StringUtils::split(input->e(e), d); - for (const auto &s:split) + for (const auto& s : split) strings.emplace_back(s); } // now once we have all strings in single vector time to fill - auto tmp = NDArrayFactory::string({(Nd4jLong) strings.size()}, strings); + auto tmp = NDArrayFactory::string({ (Nd4jLong)strings.size() }, strings, input->dataType(), block.launchContext()); auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size()); // for CUDA mostly @@ -129,9 +129,9 @@ namespace sd { DECLARE_TYPES(compat_string_split) { getOpDescriptor() - ->setAllowedInputTypes({ALL_STRINGS}) - ->setAllowedOutputTypes(0, {ALL_INDICES}) - ->setAllowedOutputTypes(1, {ALL_STRINGS}); + ->setAllowedInputTypes({ ALL_STRINGS }) + ->setAllowedOutputTypes(0, { ALL_INDICES }) + ->setAllowedOutputTypes(1, { ALL_STRINGS }); } } } diff --git a/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp index 244083a03..7d8eeec3a 100644 --- a/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp @@ -68,6 +68,7 @@ namespace sd { } case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = *weights * E.lengthOf(); else @@ -201,6 +202,7 @@ namespace sd { case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = (*weights) * E.lengthOf(); else diff --git a/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp index 0c05de0ba..a29bd1cf2 100644 --- a/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp @@ -73,6 +73,7 @@ CUSTOM_OP_IMPL(huber_loss, 3, 1, false, 1, 1) { } case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = *weights * E.lengthOf(); else @@ -216,6 +217,7 @@ DECLARE_SHAPE_FN(huber_loss) { case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = (*weights) * E.lengthOf(); else diff --git a/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp index dc889d5c9..99140a394 100644 --- a/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp @@ -70,6 +70,7 @@ CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) { } case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = *weights * E.lengthOf(); else @@ -206,6 +207,7 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) { case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = (*weights) * E.lengthOf(); else diff --git a/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp b/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp index 9a00b4eb4..20e03e92b 100644 --- a/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp @@ -74,6 +74,7 @@ namespace ops { } case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = *weights * E.lengthOf(); else @@ -209,6 +210,7 @@ namespace ops { case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = (*weights) * E.lengthOf(); else diff --git a/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp b/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp index 312a32674..f8006a3ed 100644 --- a/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp @@ -143,6 +143,7 @@ namespace sd { } case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = (*weights) * E.lengthOf(); else @@ -282,6 +283,7 @@ namespace sd { case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = (*weights) * E.lengthOf(); else diff --git a/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp b/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp index c5925fe90..b0ccf968b 100644 --- a/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp @@ -67,6 +67,7 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) { } case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = (*weights) * E.lengthOf(); else @@ -200,6 +201,7 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) { case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = (*weights) * E.lengthOf(); else diff --git a/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp b/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp index 4d3c5749c..28d66bc93 100644 --- a/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp @@ -78,6 +78,7 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) { } case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = (*weights) * E.lengthOf(); else @@ -219,6 +220,7 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) { } case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array NDArray sum; + sum.setContext(block.launchContext()); if (weights->isScalar()) sum = (*weights) * E.lengthOf(); else diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp index 43a7c7c75..1decc65f0 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp @@ -14,9 +14,9 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -// -// @author sgazeos@gmail.com -// + // + // @author sgazeos@gmail.com + // #include #include @@ -29,24 +29,24 @@ namespace sd { auto x = INPUT_VARIABLE(0); auto y = INPUT_VARIABLE(1); auto z = OUTPUT_VARIABLE(0); - auto z0 = NDArrayFactory::create(x->ordering(), x->getShapeAsVector()); + auto z0 = NDArrayFactory::create(x->ordering(), x->getShapeAsVector(), block.launchContext()); BROADCAST_CHECK_EMPTY(x, y, (&z0)); - + auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0); bitcast res; - auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, {}, false); + auto status = res.execute({ tZ }, { z }, {}, { DataType::UINT8 }, {}, {}, false); if (tZ != &z0) { delete tZ; } - + return status; } DECLARE_TYPES(compare_and_bitpack) { getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::UINT8); + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::UINT8); } DECLARE_SHAPE_FN(compare_and_bitpack) { diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp index d3ccff82a..f8a4c5c6e 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp @@ -14,9 +14,9 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -// -// Created by george@skymind.io on 26.01.2018. -// + // + // Created by george@skymind.io on 26.01.2018. + // #include #if NOT_EXCLUDED(OP_normalize_moments) @@ -34,7 +34,7 @@ namespace sd { auto resVariances = OUTPUT_VARIABLE(1); // FIXME: double? - NDArray shift = NDArrayFactory::create(0.); + NDArray shift = NDArrayFactory::create(0., block.launchContext()); if (block.getTArguments()->size() > 0) { shift.assign(T_ARG(0)); @@ -47,7 +47,7 @@ namespace sd { squareMeans.applyTransform(transform::Square, squareMeans, nullptr); variances->applyScalarArr(scalar::Divide, *counts, tempVariances); -// tempVariances.printIndexedBuffer("varianced divided by count"); + // tempVariances.printIndexedBuffer("varianced divided by count"); tempVariances.applyPairwiseTransform(pairwise::Subtract, squareMeans, *resVariances); if (shift.e(0) != 0) { @@ -75,8 +75,8 @@ namespace sd { DECLARE_TYPES(normalize_moments) { getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ ALL_FLOATS }); } } diff --git a/libnd4j/include/ops/declarable/generic/random/uniform.cpp b/libnd4j/include/ops/declarable/generic/random/uniform.cpp index 6dec96739..94df6b32d 100644 --- a/libnd4j/include/ops/declarable/generic/random/uniform.cpp +++ b/libnd4j/include/ops/declarable/generic/random/uniform.cpp @@ -49,8 +49,8 @@ namespace sd { bool disposable = false; if (min == nullptr && max == nullptr && block.numT() >= 2) { - min = NDArrayFactory::create_(dtype); - max = NDArrayFactory::create_(dtype); + min = NDArrayFactory::create_(dtype, block.launchContext()); + max = NDArrayFactory::create_(dtype, block.launchContext()); min->p(0, T_ARG(0)); max->p(0, T_ARG(1)); disposable = true; diff --git a/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp b/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp index 82dd8c36e..cf3675122 100644 --- a/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp +++ b/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp @@ -14,9 +14,9 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -// -// @author George A. Shulinok , created on 4/18/2019. -// + // + // @author George A. Shulinok , created on 4/18/2019. + // #include #if NOT_EXCLUDED(OP_barnes_symmetrized) @@ -25,20 +25,20 @@ #include namespace sd { -namespace ops { - NDArray* rowCountsPtr = nullptr; + namespace ops { + NDArray* rowCountsPtr = nullptr; - CUSTOM_OP_IMPL(barnes_symmetrized, 3, 3, false, 0, -1) { - auto rowP = INPUT_VARIABLE(0); - auto colP = INPUT_VARIABLE(1); - auto valP = INPUT_VARIABLE(2); + CUSTOM_OP_IMPL(barnes_symmetrized, 3, 3, false, 0, -1) { + auto rowP = INPUT_VARIABLE(0); + auto colP = INPUT_VARIABLE(1); + auto valP = INPUT_VARIABLE(2); auto N = rowP->lengthOf() - 1; - auto outputRows = OUTPUT_VARIABLE(0); + auto outputRows = OUTPUT_VARIABLE(0); auto outputCols = OUTPUT_VARIABLE(1); auto outputVals = OUTPUT_VARIABLE(2); - if (block.getIArguments()->size() > 0) - N = INT_ARG(0); + if (block.getIArguments()->size() > 0) + N = INT_ARG(0); if (rowCountsPtr) { helpers::barnes_symmetrize(rowP, colP, valP, N, outputRows, outputCols, outputVals, rowCountsPtr); @@ -46,33 +46,33 @@ namespace ops { return Status::OK(); } return Status::THROW("barnes_symmetrized: Cannot loop due wrong input data."); - } + } - DECLARE_TYPES(barnes_symmetrized) { - getOpDescriptor() - ->setAllowedInputTypes(0, {DataType::INT32}) - ->setAllowedInputTypes(1, {DataType::INT32}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes(1, {DataType::INT32}) - ->setAllowedOutputTypes(1, {DataType::INT32}) - ->setAllowedOutputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setSameMode(false); - } + DECLARE_TYPES(barnes_symmetrized) { + getOpDescriptor() + ->setAllowedInputTypes(0, { DataType::INT32 }) + ->setAllowedInputTypes(1, { DataType::INT32 }) + ->setAllowedInputTypes(2, { ALL_INTS, ALL_FLOATS }) + ->setAllowedOutputTypes(1, { DataType::INT32 }) + ->setAllowedOutputTypes(1, { DataType::INT32 }) + ->setAllowedOutputTypes(2, { ALL_INTS, ALL_FLOATS }) + ->setSameMode(false); + } - DECLARE_SHAPE_FN(barnes_symmetrized) { - auto valPShapeInfo = inputShape->at(2); + DECLARE_SHAPE_FN(barnes_symmetrized) { + auto valPShapeInfo = inputShape->at(2); Nd4jLong* outShapeInfo; - auto rowP = INPUT_VARIABLE(0); - auto colP = INPUT_VARIABLE(1); + auto rowP = INPUT_VARIABLE(0); + auto colP = INPUT_VARIABLE(1); auto N = rowP->lengthOf() - 1; if (block.getIArguments()->size() > 0) N = INT_ARG(0); auto dataType = rowP->dataType(); //ArrayOptions::dataType(inputShape->at(0)); - NDArray* rowCounts = NDArrayFactory::create_('c', {N}); //rowP->dup(); + NDArray* rowCounts = NDArrayFactory::create_('c', { N }, block.launchContext()); //rowP->dup(); //srowCounts->assign(0); Nd4jLong len = helpers::barnes_row_count(rowP, colP, N, *rowCounts); rowCounts->syncToHost(); -// rowCounts->printBuffer("Row Counts"); + // rowCounts->printBuffer("Row Counts"); if (len <= 0) throw std::runtime_error("barnes_symmetrized: Cannot allocate shape due non-positive len."); rowCountsPtr = rowCounts; //ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); @@ -80,13 +80,13 @@ namespace ops { // outShapeInfo[2] = len; // ShapeUtils::updateStridesAndType(outShapeInfo, ArrayOptions::dataType(valPShapeInfo), 'c'); //outShapeInfo = ShapeBuilders::createVectorShapeInfo(ArrayOptions::dataType(valPShapeInfo), len, block.workspace()); - outShapeInfo = sd::ShapeBuilders::createShapeInfo(ArrayOptions::dataType(valPShapeInfo), 'c', {1, len}, block.getWorkspace()); - auto outColsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', {1, len}, block.getWorkspace()); - auto outRowsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', {1, N + 1}, block.getWorkspace()); - return SHAPELIST(CONSTANT(outRowsShapeInfo), CONSTANT(outColsShapeInfo), CONSTANT(outShapeInfo)); - } + outShapeInfo = sd::ShapeBuilders::createShapeInfo(ArrayOptions::dataType(valPShapeInfo), 'c', { 1, len }, block.getWorkspace()); + auto outColsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', { 1, len }, block.getWorkspace()); + auto outRowsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', { 1, N + 1 }, block.getWorkspace()); + return SHAPELIST(CONSTANT(outRowsShapeInfo), CONSTANT(outColsShapeInfo), CONSTANT(outShapeInfo)); + } -} + } } #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp index 8466631da..8938a98f9 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp @@ -142,7 +142,7 @@ namespace helpers { const int rowNum = input->rows(); const int columnNum = input->columns(); - NDArray determinant = NDArrayFactory::create(1.f); + NDArray determinant = NDArrayFactory::create(1.f, context); NDArray compoundMatrix = *input; // copy NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides permutationMatrix.setIdentity(); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp b/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp index 2ea18a79d..1f980e553 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp @@ -39,7 +39,7 @@ namespace helpers { template NDArray vmul(NDArray const& v, int n) { - NDArray res('c', {n,n}, v.dataType()); // x = matrix_new(n, n); + NDArray res('c', {n,n}, v.dataType(), v.getContext()); // x = matrix_new(n, n); T const* vBuf = v.getDataBuffer()->primaryAsT(); T* resBuf = res.dataBuffer()->primaryAsT(); auto interloop = PRAGMA_THREADS_FOR_2D { @@ -61,7 +61,7 @@ namespace helpers { std::vector q(M); NDArray z = *matrix; - NDArray e('c', {M}, DataTypeUtils::fromT()); // two internal buffers and scalar for squared norm + NDArray e('c', {M}, DataTypeUtils::fromT(), Q->getContext()); // two internal buffers and scalar for squared norm for (Nd4jLong k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number e.nullify(); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/top_k.cpp b/libnd4j/include/ops/declarable/helpers/cpu/top_k.cpp index 7e0b07da0..78b06d71e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/top_k.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/top_k.cpp @@ -69,9 +69,9 @@ namespace helpers { auto trial = (*input)(e, dimsToExclude); // fill up the first k elements - NDArray topValues = NDArrayFactory::create('c', {k}); - NDArray sortedVals = NDArrayFactory::create('c', {k}); - NDArray topIndices = NDArrayFactory::create('c', {k}); + NDArray topValues = NDArrayFactory::create('c', {k}, input->getContext()); + NDArray sortedVals = NDArrayFactory::create('c', {k}, input->getContext()); + NDArray topIndices = NDArrayFactory::create('c', {k}, input->getContext()); for (uint pos = 0; pos < k; ++pos) { topIndices.t(pos) = pos; topValues.t(pos) = trial.t(pos); @@ -144,7 +144,7 @@ namespace helpers { for (int i = 0; i < input->rankOf() - 1; i++) shapeI[i] = input->sizeAt(i); shapeI[input->rankOf() - 1] = k; - std::unique_ptr indices(NDArrayFactory::create_(input->ordering(), shapeI)); + std::unique_ptr indices(NDArrayFactory::create_(input->ordering(), shapeI, context)); NDArray* values = nullptr; int status = topKFunctor(context, input, values, indices.get(), k, true); result->assign(0); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu b/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu index 51af14fc4..c6123d6da 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu @@ -112,7 +112,7 @@ namespace sd { int numThreads = 256; int numBlocks = sd::math::nd4j_max(256, sd::math::nd4j_min(1, shape::length(xShapeInfo) / numThreads)); int workspaceSize = numBlocks * numBins; - auto tmp = NDArrayFactory::create('c', {workspaceSize}); + auto tmp = NDArrayFactory::create('c', {workspaceSize}, context); histogramKernel<<getCudaStream()>>>(xBuffer, dxShapeInfo, zBuffer, zShapeInfo, tmp.getSpecialBuffer(), context->getReductionPointer(), numBins, reinterpret_cast(min_val), reinterpret_cast(max_val)); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_draw_bounding_boxes.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_draw_bounding_boxes.cu index 9817471bb..6d6ec95ed 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_draw_bounding_boxes.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_draw_bounding_boxes.cu @@ -25,7 +25,7 @@ namespace ops { namespace helpers { typedef NDArray ColorTable_t; - static NDArray DefaultColorTable(int depth) { + static NDArray DefaultColorTable(int depth, sd::LaunchContext* context) { //std::vector> colorTable; const Nd4jLong kDefaultTableLength = 10; const Nd4jLong kDefaultChannelLength = 4; @@ -40,7 +40,7 @@ namespace helpers { 0, 0, 0.5, 1, // 7: navy blue 0, 1, 1, 1, // 8: aqua 1, 0, 1, 1 // 9: fuchsia - }, DataType::FLOAT32); + }, DataType::FLOAT32, context); if (depth == 1) { colorTable.assign(1.f); // all to white when black and white colors @@ -144,7 +144,7 @@ namespace helpers { auto channels = images->sizeAt(3); auto stream = context->getCudaStream(); auto boxSize = boxes->sizeAt(1); - NDArray colorsTable = DefaultColorTable(channels); + NDArray colorsTable = DefaultColorTable(channels, context); if ((colors != nullptr && colors->lengthOf() > 0)) { colorsTable = *colors; } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu index 5ed534cb6..e6d9a27b1 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu @@ -188,7 +188,7 @@ namespace helpers { static void nonMaxSuppressionV2_(sd::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) { auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {boxes, scales}); - std::unique_ptr indices(NDArrayFactory::create_('c', {scales->lengthOf()})); // - 1, scales->lengthOf()); //, scales->getContext()); + std::unique_ptr indices(NDArrayFactory::create_('c', {scales->lengthOf()}, context)); // - 1, scales->lengthOf()); //, scales->getContext()); NDArray scores(*scales); Nd4jPointer extras[2] = {nullptr, stream}; @@ -198,7 +198,7 @@ namespace helpers { indices->tickWriteDevice(); sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true); indices->tickWriteDevice(); - NDArray selectedIndices = NDArrayFactory::create('c', {output->lengthOf()}); + NDArray selectedIndices = NDArrayFactory::create('c', {output->lengthOf()}, context); int numSelected = 0; int numBoxes = boxes->sizeAt(0); auto boxesBuf = reinterpret_cast(boxes->specialBuffer()); @@ -347,8 +347,8 @@ namespace helpers { scores->syncToDevice(); } - NDArray indices = NDArrayFactory::create('c', {scores->lengthOf()}); // - 1, scales->lengthOf()); //, scales->getContext()); - NDArray startPositions = NDArrayFactory::create('c', {scores->lengthOf()}); + NDArray indices = NDArrayFactory::create('c', {scores->lengthOf()}, context); // - 1, scales->lengthOf()); //, scales->getContext()); + NDArray startPositions = NDArrayFactory::create('c', {scores->lengthOf()}, context); NDArray selectedScores(*scores); Nd4jPointer extras[2] = {nullptr, stream}; auto indexBuf = indices.dataBuffer()->specialAsT();///reinterpret_cast(indices->specialBuffer()); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu index 7630694e1..2ca731912 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu @@ -598,7 +598,7 @@ namespace helpers { static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) { auto n = input->sizeAt(-1); auto stream = context->getCudaStream(); - NDArray iota('c', {n}, permutationVectors->dataType());// = NDArrayFactory::create(); // ('c', {n}); + NDArray iota('c', {n}, permutationVectors->dataType(), context);// = NDArrayFactory::create(); // ('c', {n}); iota.linspace(0); iota.syncToDevice(); output->assign(input); // fill up output tensor with zeros @@ -631,7 +631,7 @@ namespace helpers { // if (dtype != DataType::DOUBLE) // dtype = DataType::FLOAT32; auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT(), context); //, block.getWorkspace()); - auto det = NDArrayFactory::create(1); + auto det = NDArrayFactory::create(1, context); auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {input}); dim3 launchDims(256, 256, 1024); @@ -677,7 +677,7 @@ namespace helpers { dtype = DataType::FLOAT32; auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace()); - auto det = NDArrayFactory::create(1); + auto det = NDArrayFactory::create(1, context); auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {input}); dim3 launchDims(256, 256, 1024); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/qr.cu b/libnd4j/include/ops/declarable/helpers/cuda/qr.cu index d0d5fddd5..394840376 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/qr.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/qr.cu @@ -110,7 +110,7 @@ namespace helpers { auto resR = fullMatricies?R->ulike():matrix->ulike(); std::vector q(M); NDArray z = *matrix; - NDArray e('c', {M}, DataTypeUtils::fromT()); // two internal buffers and scalar for squared norm + NDArray e('c', {M}, DataTypeUtils::fromT(), context); // two internal buffers and scalar for squared norm for (auto k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number e.nullify(); z = matrixMinor(context, z, k); // minor computing for current column with given matrix z (initally is a input matrix) @@ -177,4 +177,3 @@ namespace helpers { } } } - diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu index 208bf764b..e7baf2370 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu @@ -167,8 +167,8 @@ namespace sd { auto stream = context->getCudaStream(); indices->syncToHost(); Nd4jLong numOfClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}); + NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); @@ -209,8 +209,8 @@ namespace sd { // NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); output->assign(DataTypeUtils::infOrMax()); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); // NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), row, classes); classesRangesBegs.assign(indices->lengthOf()); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu index fa8882190..76036a5e6 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu @@ -158,8 +158,8 @@ namespace helpers { static void segmentMeanFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { auto stream = context->getCudaStream(); Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}); + NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); @@ -198,8 +198,8 @@ namespace helpers { auto stream = context->getCudaStream(); // NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); // NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); classesRangesBegs.assign(indices->lengthOf()); @@ -314,8 +314,8 @@ namespace helpers { auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); auto numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}); + NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); @@ -367,8 +367,8 @@ namespace helpers { auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); auto numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}); + NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu index b83d37567..0133b3b11 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu @@ -161,8 +161,8 @@ namespace helpers { static void segmentMinFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { auto stream = context->getCudaStream(); Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}); + NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); output->assign(DataTypeUtils::infOrMax()); classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); @@ -202,8 +202,8 @@ namespace helpers { static void unsortedSegmentMinFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { auto stream = context->getCudaStream(); // NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); // NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); output->assign(DataTypeUtils::infOrMax()); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu index baec75b9e..d08f79817 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu @@ -122,8 +122,8 @@ namespace helpers { static void segmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { auto stream = context->getCudaStream(); Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}); + NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); output->assign(1); classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); @@ -160,8 +160,8 @@ namespace helpers { static void unsortedSegmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { auto stream = context->getCudaStream(); // NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); // NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); classesRangesBegs.assign(indices->lengthOf()); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu index 7d85a0ea6..f9b6eaad0 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu @@ -86,8 +86,8 @@ namespace helpers { static void unsortedSegmentSqrtNFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { auto stream = context->getCudaStream(); // NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); // NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); classesRangesBegs.assign(indices->lengthOf()); @@ -207,8 +207,8 @@ namespace helpers { auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); auto numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}); + NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu index 6e1e3fca8..56d53710f 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu @@ -162,8 +162,8 @@ namespace helpers { static void segmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { auto stream = context->getCudaStream(); Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}); + NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); @@ -201,8 +201,8 @@ namespace helpers { static void unsortedSegmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { auto stream = context->getCudaStream(); // NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); // NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); classesRangesBegs.assign(indices->lengthOf()); diff --git a/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp b/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp index ff9ae3144..5989f5246 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp @@ -41,7 +41,7 @@ namespace helpers { length += array->lengthOf(); pos++; } - NDArray arrayFull('c', {length}, sd::DataType::INT32); + NDArray arrayFull('c', {length}, sd::DataType::INT32, inputList[0]->getContext()); cContext.setOutputArray(0, &arrayFull); cContext.setIArguments(&axis, 1);