libnd4j fixes for context sync in operation execution (#350)

Signed-off-by: Oleg <oleg.semeniv@gmail.com>
master
Oleh 2020-03-30 16:33:51 +03:00 committed by GitHub
parent 9b3576bc00
commit bf0ddbc06c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 167 additions and 140 deletions

View File

@ -23,6 +23,7 @@
#include <array/NDArrayList.h> #include <array/NDArrayList.h>
#include <helpers/ShapeUtils.h> #include <helpers/ShapeUtils.h>
#include <ops/declarable/CustomOperations.h> #include <ops/declarable/CustomOperations.h>
#include<ops/declarable/helpers/stack.h>
namespace sd { namespace sd {
NDArrayList::NDArrayList(int height, bool expandable) { NDArrayList::NDArrayList(int height, bool expandable) {
@ -144,25 +145,38 @@ namespace sd {
NDArray* NDArrayList::stack() { NDArray* NDArrayList::stack() {
// FIXME: this is bad for perf, but ok as poc // FIXME: this is bad for perf, but ok as poc
sd::ops::stack op;
std::vector<NDArray*> inputs;
std::vector<double> targs;
std::vector<Nd4jLong> iargs({0});
std::vector<bool> bargs;
int numElements = _elements.load(); int numElements = _elements.load();
std::vector<const NDArray*> inputs(numElements);
for (int e = 0; e < numElements; e++) { for (int e = 0; e < numElements; e++) {
_chunks[e]->syncToDevice(); _chunks[e]->syncToDevice();
inputs.emplace_back(_chunks[e]); inputs[e] = _chunks[e];
} }
iargs.push_back(_axis); auto inShapeInfo = inputs[0]->getShapeInfo();
int rank = shape::rank(inShapeInfo);
NDArray* array = nullptr;
auto result = op.evaluate(inputs); if (shape::isEmpty(inShapeInfo)) {
switch (rank) {
case 0: {
if (numElements == 1) {
array = new NDArray(inputs[0]->ordering(), {0}, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext());
} else {
array = new NDArray('c', {(Nd4jLong) numElements, 0}, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext() ) ;
}
}
}
}
else{
std::vector<Nd4jLong> outShape(inShapeInfo + 1, inShapeInfo + 1 + rank);
outShape.insert(outShape.begin(), (Nd4jLong) numElements);
array = new NDArray( shape::order(inShapeInfo), outShape, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext());
}
ops::helpers::stack(inputs[0]->getContext(), inputs, *array, 0);
auto array = new NDArray(result.at(0)->dup()); return array;
return array;
} }
std::pair<int,int>& NDArrayList::id() { std::pair<int,int>& NDArrayList::id() {

View File

@ -14,10 +14,10 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
// //
// @author raver119@gmail.com // @author raver119@gmail.com
// modified by sgazeos@gmail.com with backprop implementation. // modified by sgazeos@gmail.com with backprop implementation.
// //
#include <system/op_boilerplate.h> #include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_floormod) #if NOT_EXCLUDED(OP_floormod)
@ -31,7 +31,7 @@ namespace sd {
auto y = INPUT_VARIABLE(1); auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
BROADCAST_CHECK_EMPTY(x,y,z); BROADCAST_CHECK_EMPTY(x, y, z);
REQUIRE_TRUE(!y->isB(), 0, "FLOORMOD OP: you can't divide by bool array!"); REQUIRE_TRUE(!y->isB(), 0, "FLOORMOD OP: you can't divide by bool array!");
auto tZ = BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, z); auto tZ = BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, z);
@ -46,15 +46,15 @@ namespace sd {
DECLARE_TYPES(floormod) { DECLARE_TYPES(floormod) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(0, DataType::ANY) ->setAllowedInputTypes(0, DataType::ANY)
->setAllowedInputTypes(1, DataType::ANY) ->setAllowedInputTypes(1, DataType::ANY)
->setAllowedOutputTypes(0, DataType::INHERIT); ->setAllowedOutputTypes(0, DataType::INHERIT);
} }
DECLARE_TYPES(floormod_bp) { DECLARE_TYPES(floormod_bp) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(DataType::ANY) ->setAllowedInputTypes(DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS}); ->setAllowedOutputTypes({ ALL_FLOATS });
} }
CUSTOM_OP_IMPL(floormod_bp, 3, 2, false, 0, 0) { CUSTOM_OP_IMPL(floormod_bp, 3, 2, false, 0, 0) {
@ -66,11 +66,11 @@ namespace sd {
auto gradY = OUTPUT_VARIABLE(1); auto gradY = OUTPUT_VARIABLE(1);
gradX->assign(epsNext); gradX->assign(epsNext);
sd::ops::floormod op; NDArray temp(*epsNext);
auto tmpResult(op.evaluate({x, y})); BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, &temp);
if (gradY->rankOf() == gradX->rankOf()) if (gradY->rankOf() == gradX->rankOf())
epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult.at(0), *gradY); epsNext->applyPairwiseTransform(pairwise::Multiply, temp, *gradY);
else // epsNext is greater than gradY else // epsNext is greater than gradY
{ {
std::vector<Nd4jLong> dims(epsNext->rankOf() * 2); std::vector<Nd4jLong> dims(epsNext->rankOf() * 2);
@ -78,7 +78,7 @@ namespace sd {
for (Nd4jLong d = 0; d < gap; d++) { for (Nd4jLong d = 0; d < gap; d++) {
dims[d * 2 + 1] = 1; dims[d * 2 + 1] = 1;
} }
auto tempIn((*tmpResult.at(0))(dims)); auto tempIn((temp)(dims));
(*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY); (*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY);
} }
return Status::OK(); return Status::OK();
@ -92,8 +92,8 @@ namespace sd {
// eps always has shape of x // eps always has shape of x
// grad always has shape of y // grad always has shape of y
Nd4jLong *shapeE; Nd4jLong* shapeE;
Nd4jLong *shapeG; Nd4jLong* shapeG;
COPY_SHAPE(x, shapeE); COPY_SHAPE(x, shapeE);
COPY_SHAPE(y, shapeG); COPY_SHAPE(y, shapeG);

View File

@ -14,9 +14,9 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
// //
// @author raver119@gmail.com // @author raver119@gmail.com
// //
#include <system/op_boilerplate.h> #include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_split_string) #if NOT_EXCLUDED(OP_split_string)
@ -60,7 +60,7 @@ namespace sd {
// filling output indices // filling output indices
for (uint64_t f = 0; f < cnt; f++) { for (uint64_t f = 0; f < cnt; f++) {
for (auto v: icoords) for (auto v : icoords)
indices->p(ic++, v); indices->p(ic++, v);
// last index // last index
@ -75,12 +75,12 @@ namespace sd {
for (auto e = 0L; e < input->lengthOf(); e++) { for (auto e = 0L; e < input->lengthOf(); e++) {
auto split = StringUtils::split(input->e<std::string>(e), d); auto split = StringUtils::split(input->e<std::string>(e), d);
for (const auto &s:split) for (const auto& s : split)
strings.emplace_back(s); strings.emplace_back(s);
} }
// now once we have all strings in single vector time to fill // now once we have all strings in single vector time to fill
auto tmp = NDArrayFactory::string({(Nd4jLong) strings.size()}, strings); auto tmp = NDArrayFactory::string({ (Nd4jLong)strings.size() }, strings, input->dataType(), block.launchContext());
auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size()); auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size());
// for CUDA mostly // for CUDA mostly
@ -129,9 +129,9 @@ namespace sd {
DECLARE_TYPES(compat_string_split) { DECLARE_TYPES(compat_string_split) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes({ALL_STRINGS}) ->setAllowedInputTypes({ ALL_STRINGS })
->setAllowedOutputTypes(0, {ALL_INDICES}) ->setAllowedOutputTypes(0, { ALL_INDICES })
->setAllowedOutputTypes(1, {ALL_STRINGS}); ->setAllowedOutputTypes(1, { ALL_STRINGS });
} }
} }
} }

View File

@ -68,6 +68,7 @@ namespace sd {
} }
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = *weights * E.lengthOf(); sum = *weights * E.lengthOf();
else else
@ -201,6 +202,7 @@ namespace sd {
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = (*weights) * E.lengthOf(); sum = (*weights) * E.lengthOf();
else else

View File

@ -73,6 +73,7 @@ CUSTOM_OP_IMPL(huber_loss, 3, 1, false, 1, 1) {
} }
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = *weights * E.lengthOf(); sum = *weights * E.lengthOf();
else else
@ -216,6 +217,7 @@ DECLARE_SHAPE_FN(huber_loss) {
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = (*weights) * E.lengthOf(); sum = (*weights) * E.lengthOf();
else else

View File

@ -70,6 +70,7 @@ CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) {
} }
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = *weights * E.lengthOf(); sum = *weights * E.lengthOf();
else else
@ -206,6 +207,7 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) {
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = (*weights) * E.lengthOf(); sum = (*weights) * E.lengthOf();
else else

View File

@ -74,6 +74,7 @@ namespace ops {
} }
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = *weights * E.lengthOf(); sum = *weights * E.lengthOf();
else else
@ -209,6 +210,7 @@ namespace ops {
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = (*weights) * E.lengthOf(); sum = (*weights) * E.lengthOf();
else else

View File

@ -143,6 +143,7 @@ namespace sd {
} }
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = (*weights) * E.lengthOf(); sum = (*weights) * E.lengthOf();
else else
@ -282,6 +283,7 @@ namespace sd {
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = (*weights) * E.lengthOf(); sum = (*weights) * E.lengthOf();
else else

View File

@ -67,6 +67,7 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) {
} }
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = (*weights) * E.lengthOf(); sum = (*weights) * E.lengthOf();
else else
@ -200,6 +201,7 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) {
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = (*weights) * E.lengthOf(); sum = (*weights) * E.lengthOf();
else else

View File

@ -78,6 +78,7 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) {
} }
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = (*weights) * E.lengthOf(); sum = (*weights) * E.lengthOf();
else else
@ -219,6 +220,7 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
} }
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
NDArray sum; NDArray sum;
sum.setContext(block.launchContext());
if (weights->isScalar()) if (weights->isScalar())
sum = (*weights) * E.lengthOf(); sum = (*weights) * E.lengthOf();
else else

View File

@ -14,9 +14,9 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
// //
// @author sgazeos@gmail.com // @author sgazeos@gmail.com
// //
#include <ops/declarable/generic/helpers/BroadcastHelper.h> #include <ops/declarable/generic/helpers/BroadcastHelper.h>
#include <ops/declarable/headers/parity_ops.h> #include <ops/declarable/headers/parity_ops.h>
@ -29,24 +29,24 @@ namespace sd {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1); auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
auto z0 = NDArrayFactory::create<bool>(x->ordering(), x->getShapeAsVector()); auto z0 = NDArrayFactory::create<bool>(x->ordering(), x->getShapeAsVector(), block.launchContext());
BROADCAST_CHECK_EMPTY(x, y, (&z0)); BROADCAST_CHECK_EMPTY(x, y, (&z0));
auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0); auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0);
bitcast res; bitcast res;
auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, {}, false); auto status = res.execute({ tZ }, { z }, {}, { DataType::UINT8 }, {}, {}, false);
if (tZ != &z0) { if (tZ != &z0) {
delete tZ; delete tZ;
} }
return status; return status;
} }
DECLARE_TYPES(compare_and_bitpack) { DECLARE_TYPES(compare_and_bitpack) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(0, DataType::ANY) ->setAllowedInputTypes(0, DataType::ANY)
->setAllowedInputTypes(1, DataType::ANY) ->setAllowedInputTypes(1, DataType::ANY)
->setAllowedOutputTypes(0, DataType::UINT8); ->setAllowedOutputTypes(0, DataType::UINT8);
} }
DECLARE_SHAPE_FN(compare_and_bitpack) { DECLARE_SHAPE_FN(compare_and_bitpack) {

View File

@ -14,9 +14,9 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
// //
// Created by george@skymind.io on 26.01.2018. // Created by george@skymind.io on 26.01.2018.
// //
#include <system/op_boilerplate.h> #include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_normalize_moments) #if NOT_EXCLUDED(OP_normalize_moments)
@ -34,7 +34,7 @@ namespace sd {
auto resVariances = OUTPUT_VARIABLE(1); auto resVariances = OUTPUT_VARIABLE(1);
// FIXME: double? // FIXME: double?
NDArray shift = NDArrayFactory::create<double>(0.); NDArray shift = NDArrayFactory::create<double>(0., block.launchContext());
if (block.getTArguments()->size() > 0) { if (block.getTArguments()->size() > 0) {
shift.assign(T_ARG(0)); shift.assign(T_ARG(0));
@ -47,7 +47,7 @@ namespace sd {
squareMeans.applyTransform(transform::Square, squareMeans, nullptr); squareMeans.applyTransform(transform::Square, squareMeans, nullptr);
variances->applyScalarArr(scalar::Divide, *counts, tempVariances); variances->applyScalarArr(scalar::Divide, *counts, tempVariances);
// tempVariances.printIndexedBuffer("varianced divided by count"); // tempVariances.printIndexedBuffer("varianced divided by count");
tempVariances.applyPairwiseTransform(pairwise::Subtract, squareMeans, *resVariances); tempVariances.applyPairwiseTransform(pairwise::Subtract, squareMeans, *resVariances);
if (shift.e<double>(0) != 0) { if (shift.e<double>(0) != 0) {
@ -75,8 +75,8 @@ namespace sd {
DECLARE_TYPES(normalize_moments) { DECLARE_TYPES(normalize_moments) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(sd::DataType::ANY) ->setAllowedInputTypes(sd::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS}); ->setAllowedOutputTypes({ ALL_FLOATS });
} }
} }

View File

@ -49,8 +49,8 @@ namespace sd {
bool disposable = false; bool disposable = false;
if (min == nullptr && max == nullptr && block.numT() >= 2) { if (min == nullptr && max == nullptr && block.numT() >= 2) {
min = NDArrayFactory::create_(dtype); min = NDArrayFactory::create_(dtype, block.launchContext());
max = NDArrayFactory::create_(dtype); max = NDArrayFactory::create_(dtype, block.launchContext());
min->p(0, T_ARG(0)); min->p(0, T_ARG(0));
max->p(0, T_ARG(1)); max->p(0, T_ARG(1));
disposable = true; disposable = true;

View File

@ -14,9 +14,9 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
// //
// @author George A. Shulinok <sgazeos@gmail.com>, created on 4/18/2019. // @author George A. Shulinok <sgazeos@gmail.com>, created on 4/18/2019.
// //
#include <system/op_boilerplate.h> #include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_barnes_symmetrized) #if NOT_EXCLUDED(OP_barnes_symmetrized)
@ -25,20 +25,20 @@
#include <ops/declarable/helpers/BarnesHutTsne.h> #include <ops/declarable/helpers/BarnesHutTsne.h>
namespace sd { namespace sd {
namespace ops { namespace ops {
NDArray* rowCountsPtr = nullptr; NDArray* rowCountsPtr = nullptr;
CUSTOM_OP_IMPL(barnes_symmetrized, 3, 3, false, 0, -1) { CUSTOM_OP_IMPL(barnes_symmetrized, 3, 3, false, 0, -1) {
auto rowP = INPUT_VARIABLE(0); auto rowP = INPUT_VARIABLE(0);
auto colP = INPUT_VARIABLE(1); auto colP = INPUT_VARIABLE(1);
auto valP = INPUT_VARIABLE(2); auto valP = INPUT_VARIABLE(2);
auto N = rowP->lengthOf() - 1; auto N = rowP->lengthOf() - 1;
auto outputRows = OUTPUT_VARIABLE(0); auto outputRows = OUTPUT_VARIABLE(0);
auto outputCols = OUTPUT_VARIABLE(1); auto outputCols = OUTPUT_VARIABLE(1);
auto outputVals = OUTPUT_VARIABLE(2); auto outputVals = OUTPUT_VARIABLE(2);
if (block.getIArguments()->size() > 0) if (block.getIArguments()->size() > 0)
N = INT_ARG(0); N = INT_ARG(0);
if (rowCountsPtr) { if (rowCountsPtr) {
helpers::barnes_symmetrize(rowP, colP, valP, N, outputRows, outputCols, outputVals, rowCountsPtr); helpers::barnes_symmetrize(rowP, colP, valP, N, outputRows, outputCols, outputVals, rowCountsPtr);
@ -46,33 +46,33 @@ namespace ops {
return Status::OK(); return Status::OK();
} }
return Status::THROW("barnes_symmetrized: Cannot loop due wrong input data."); return Status::THROW("barnes_symmetrized: Cannot loop due wrong input data.");
} }
DECLARE_TYPES(barnes_symmetrized) { DECLARE_TYPES(barnes_symmetrized) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(0, {DataType::INT32}) ->setAllowedInputTypes(0, { DataType::INT32 })
->setAllowedInputTypes(1, {DataType::INT32}) ->setAllowedInputTypes(1, { DataType::INT32 })
->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) ->setAllowedInputTypes(2, { ALL_INTS, ALL_FLOATS })
->setAllowedOutputTypes(1, {DataType::INT32}) ->setAllowedOutputTypes(1, { DataType::INT32 })
->setAllowedOutputTypes(1, {DataType::INT32}) ->setAllowedOutputTypes(1, { DataType::INT32 })
->setAllowedOutputTypes(2, {ALL_INTS, ALL_FLOATS}) ->setAllowedOutputTypes(2, { ALL_INTS, ALL_FLOATS })
->setSameMode(false); ->setSameMode(false);
} }
DECLARE_SHAPE_FN(barnes_symmetrized) { DECLARE_SHAPE_FN(barnes_symmetrized) {
auto valPShapeInfo = inputShape->at(2); auto valPShapeInfo = inputShape->at(2);
Nd4jLong* outShapeInfo; Nd4jLong* outShapeInfo;
auto rowP = INPUT_VARIABLE(0); auto rowP = INPUT_VARIABLE(0);
auto colP = INPUT_VARIABLE(1); auto colP = INPUT_VARIABLE(1);
auto N = rowP->lengthOf() - 1; auto N = rowP->lengthOf() - 1;
if (block.getIArguments()->size() > 0) if (block.getIArguments()->size() > 0)
N = INT_ARG(0); N = INT_ARG(0);
auto dataType = rowP->dataType(); //ArrayOptions::dataType(inputShape->at(0)); auto dataType = rowP->dataType(); //ArrayOptions::dataType(inputShape->at(0));
NDArray* rowCounts = NDArrayFactory::create_<int>('c', {N}); //rowP->dup(); NDArray* rowCounts = NDArrayFactory::create_<int>('c', { N }, block.launchContext()); //rowP->dup();
//srowCounts->assign(0); //srowCounts->assign(0);
Nd4jLong len = helpers::barnes_row_count(rowP, colP, N, *rowCounts); Nd4jLong len = helpers::barnes_row_count(rowP, colP, N, *rowCounts);
rowCounts->syncToHost(); rowCounts->syncToHost();
// rowCounts->printBuffer("Row Counts"); // rowCounts->printBuffer("Row Counts");
if (len <= 0) throw std::runtime_error("barnes_symmetrized: Cannot allocate shape due non-positive len."); if (len <= 0) throw std::runtime_error("barnes_symmetrized: Cannot allocate shape due non-positive len.");
rowCountsPtr = rowCounts; rowCountsPtr = rowCounts;
//ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); //ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(2), Nd4jLong);
@ -80,13 +80,13 @@ namespace ops {
// outShapeInfo[2] = len; // outShapeInfo[2] = len;
// ShapeUtils::updateStridesAndType(outShapeInfo, ArrayOptions::dataType(valPShapeInfo), 'c'); // ShapeUtils::updateStridesAndType(outShapeInfo, ArrayOptions::dataType(valPShapeInfo), 'c');
//outShapeInfo = ShapeBuilders::createVectorShapeInfo(ArrayOptions::dataType(valPShapeInfo), len, block.workspace()); //outShapeInfo = ShapeBuilders::createVectorShapeInfo(ArrayOptions::dataType(valPShapeInfo), len, block.workspace());
outShapeInfo = sd::ShapeBuilders::createShapeInfo(ArrayOptions::dataType(valPShapeInfo), 'c', {1, len}, block.getWorkspace()); outShapeInfo = sd::ShapeBuilders::createShapeInfo(ArrayOptions::dataType(valPShapeInfo), 'c', { 1, len }, block.getWorkspace());
auto outColsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', {1, len}, block.getWorkspace()); auto outColsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', { 1, len }, block.getWorkspace());
auto outRowsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', {1, N + 1}, block.getWorkspace()); auto outRowsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', { 1, N + 1 }, block.getWorkspace());
return SHAPELIST(CONSTANT(outRowsShapeInfo), CONSTANT(outColsShapeInfo), CONSTANT(outShapeInfo)); return SHAPELIST(CONSTANT(outRowsShapeInfo), CONSTANT(outColsShapeInfo), CONSTANT(outShapeInfo));
} }
} }
} }
#endif #endif

View File

@ -142,7 +142,7 @@ namespace helpers {
const int rowNum = input->rows(); const int rowNum = input->rows();
const int columnNum = input->columns(); const int columnNum = input->columns();
NDArray determinant = NDArrayFactory::create<T>(1.f); NDArray determinant = NDArrayFactory::create<T>(1.f, context);
NDArray compoundMatrix = *input; // copy NDArray compoundMatrix = *input; // copy
NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides
permutationMatrix.setIdentity(); permutationMatrix.setIdentity();

View File

@ -39,7 +39,7 @@ namespace helpers {
template <typename T> template <typename T>
NDArray vmul(NDArray const& v, int n) NDArray vmul(NDArray const& v, int n)
{ {
NDArray res('c', {n,n}, v.dataType()); // x = matrix_new(n, n); NDArray res('c', {n,n}, v.dataType(), v.getContext()); // x = matrix_new(n, n);
T const* vBuf = v.getDataBuffer()->primaryAsT<T>(); T const* vBuf = v.getDataBuffer()->primaryAsT<T>();
T* resBuf = res.dataBuffer()->primaryAsT<T>(); T* resBuf = res.dataBuffer()->primaryAsT<T>();
auto interloop = PRAGMA_THREADS_FOR_2D { auto interloop = PRAGMA_THREADS_FOR_2D {
@ -61,7 +61,7 @@ namespace helpers {
std::vector<NDArray> q(M); std::vector<NDArray> q(M);
NDArray z = *matrix; NDArray z = *matrix;
NDArray e('c', {M}, DataTypeUtils::fromT<T>()); // two internal buffers and scalar for squared norm NDArray e('c', {M}, DataTypeUtils::fromT<T>(), Q->getContext()); // two internal buffers and scalar for squared norm
for (Nd4jLong k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number for (Nd4jLong k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number
e.nullify(); e.nullify();

View File

@ -69,9 +69,9 @@ namespace helpers {
auto trial = (*input)(e, dimsToExclude); auto trial = (*input)(e, dimsToExclude);
// fill up the first k elements // fill up the first k elements
NDArray topValues = NDArrayFactory::create<T>('c', {k}); NDArray topValues = NDArrayFactory::create<T>('c', {k}, input->getContext());
NDArray sortedVals = NDArrayFactory::create<T>('c', {k}); NDArray sortedVals = NDArrayFactory::create<T>('c', {k}, input->getContext());
NDArray topIndices = NDArrayFactory::create<Nd4jLong>('c', {k}); NDArray topIndices = NDArrayFactory::create<Nd4jLong>('c', {k}, input->getContext());
for (uint pos = 0; pos < k; ++pos) { for (uint pos = 0; pos < k; ++pos) {
topIndices.t<Nd4jLong>(pos) = pos; topIndices.t<Nd4jLong>(pos) = pos;
topValues.t<T>(pos) = trial.t<T>(pos); topValues.t<T>(pos) = trial.t<T>(pos);
@ -144,7 +144,7 @@ namespace helpers {
for (int i = 0; i < input->rankOf() - 1; i++) for (int i = 0; i < input->rankOf() - 1; i++)
shapeI[i] = input->sizeAt(i); shapeI[i] = input->sizeAt(i);
shapeI[input->rankOf() - 1] = k; shapeI[input->rankOf() - 1] = k;
std::unique_ptr<NDArray> indices(NDArrayFactory::create_<Nd4jLong>(input->ordering(), shapeI)); std::unique_ptr<NDArray> indices(NDArrayFactory::create_<Nd4jLong>(input->ordering(), shapeI, context));
NDArray* values = nullptr; NDArray* values = nullptr;
int status = topKFunctor(context, input, values, indices.get(), k, true); int status = topKFunctor(context, input, values, indices.get(), k, true);
result->assign(0); result->assign(0);

View File

@ -112,7 +112,7 @@ namespace sd {
int numThreads = 256; int numThreads = 256;
int numBlocks = sd::math::nd4j_max<int>(256, sd::math::nd4j_min<int>(1, shape::length(xShapeInfo) / numThreads)); int numBlocks = sd::math::nd4j_max<int>(256, sd::math::nd4j_min<int>(1, shape::length(xShapeInfo) / numThreads));
int workspaceSize = numBlocks * numBins; int workspaceSize = numBlocks * numBins;
auto tmp = NDArrayFactory::create<Z>('c', {workspaceSize}); auto tmp = NDArrayFactory::create<Z>('c', {workspaceSize}, context);
histogramKernel<X, Z><<<numBlocks, numThreads, 32768, *context->getCudaStream()>>>(xBuffer, dxShapeInfo, zBuffer, zShapeInfo, tmp.getSpecialBuffer(), context->getReductionPointer(), numBins, reinterpret_cast<X*>(min_val), reinterpret_cast<X*>(max_val)); histogramKernel<X, Z><<<numBlocks, numThreads, 32768, *context->getCudaStream()>>>(xBuffer, dxShapeInfo, zBuffer, zShapeInfo, tmp.getSpecialBuffer(), context->getReductionPointer(), numBins, reinterpret_cast<X*>(min_val), reinterpret_cast<X*>(max_val));

View File

@ -25,7 +25,7 @@ namespace ops {
namespace helpers { namespace helpers {
typedef NDArray ColorTable_t; typedef NDArray ColorTable_t;
static NDArray DefaultColorTable(int depth) { static NDArray DefaultColorTable(int depth, sd::LaunchContext* context) {
//std::vector<std::vector<float>> colorTable; //std::vector<std::vector<float>> colorTable;
const Nd4jLong kDefaultTableLength = 10; const Nd4jLong kDefaultTableLength = 10;
const Nd4jLong kDefaultChannelLength = 4; const Nd4jLong kDefaultChannelLength = 4;
@ -40,7 +40,7 @@ namespace helpers {
0, 0, 0.5, 1, // 7: navy blue 0, 0, 0.5, 1, // 7: navy blue
0, 1, 1, 1, // 8: aqua 0, 1, 1, 1, // 8: aqua
1, 0, 1, 1 // 9: fuchsia 1, 0, 1, 1 // 9: fuchsia
}, DataType::FLOAT32); }, DataType::FLOAT32, context);
if (depth == 1) { if (depth == 1) {
colorTable.assign(1.f); // all to white when black and white colors colorTable.assign(1.f); // all to white when black and white colors
@ -144,7 +144,7 @@ namespace helpers {
auto channels = images->sizeAt(3); auto channels = images->sizeAt(3);
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
auto boxSize = boxes->sizeAt(1); auto boxSize = boxes->sizeAt(1);
NDArray colorsTable = DefaultColorTable(channels); NDArray colorsTable = DefaultColorTable(channels, context);
if ((colors != nullptr && colors->lengthOf() > 0)) { if ((colors != nullptr && colors->lengthOf() > 0)) {
colorsTable = *colors; colorsTable = *colors;
} }

View File

@ -188,7 +188,7 @@ namespace helpers {
static void nonMaxSuppressionV2_(sd::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) { static void nonMaxSuppressionV2_(sd::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {boxes, scales}); NDArray::prepareSpecialUse({output}, {boxes, scales});
std::unique_ptr<NDArray> indices(NDArrayFactory::create_<I>('c', {scales->lengthOf()})); // - 1, scales->lengthOf()); //, scales->getContext()); std::unique_ptr<NDArray> indices(NDArrayFactory::create_<I>('c', {scales->lengthOf()}, context)); // - 1, scales->lengthOf()); //, scales->getContext());
NDArray scores(*scales); NDArray scores(*scales);
Nd4jPointer extras[2] = {nullptr, stream}; Nd4jPointer extras[2] = {nullptr, stream};
@ -198,7 +198,7 @@ namespace helpers {
indices->tickWriteDevice(); indices->tickWriteDevice();
sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true); sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true);
indices->tickWriteDevice(); indices->tickWriteDevice();
NDArray selectedIndices = NDArrayFactory::create<I>('c', {output->lengthOf()}); NDArray selectedIndices = NDArrayFactory::create<I>('c', {output->lengthOf()}, context);
int numSelected = 0; int numSelected = 0;
int numBoxes = boxes->sizeAt(0); int numBoxes = boxes->sizeAt(0);
auto boxesBuf = reinterpret_cast<T*>(boxes->specialBuffer()); auto boxesBuf = reinterpret_cast<T*>(boxes->specialBuffer());
@ -347,8 +347,8 @@ namespace helpers {
scores->syncToDevice(); scores->syncToDevice();
} }
NDArray indices = NDArrayFactory::create<I>('c', {scores->lengthOf()}); // - 1, scales->lengthOf()); //, scales->getContext()); NDArray indices = NDArrayFactory::create<I>('c', {scores->lengthOf()}, context); // - 1, scales->lengthOf()); //, scales->getContext());
NDArray startPositions = NDArrayFactory::create<I>('c', {scores->lengthOf()}); NDArray startPositions = NDArrayFactory::create<I>('c', {scores->lengthOf()}, context);
NDArray selectedScores(*scores); NDArray selectedScores(*scores);
Nd4jPointer extras[2] = {nullptr, stream}; Nd4jPointer extras[2] = {nullptr, stream};
auto indexBuf = indices.dataBuffer()->specialAsT<I>();///reinterpret_cast<I*>(indices->specialBuffer()); auto indexBuf = indices.dataBuffer()->specialAsT<I>();///reinterpret_cast<I*>(indices->specialBuffer());

View File

@ -598,7 +598,7 @@ namespace helpers {
static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) { static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) {
auto n = input->sizeAt(-1); auto n = input->sizeAt(-1);
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray iota('c', {n}, permutationVectors->dataType());// = NDArrayFactory::create(); // <int>('c', {n}); NDArray iota('c', {n}, permutationVectors->dataType(), context);// = NDArrayFactory::create(); // <int>('c', {n});
iota.linspace(0); iota.syncToDevice(); iota.linspace(0); iota.syncToDevice();
output->assign(input); // fill up output tensor with zeros output->assign(input); // fill up output tensor with zeros
@ -631,7 +631,7 @@ namespace helpers {
// if (dtype != DataType::DOUBLE) // if (dtype != DataType::DOUBLE)
// dtype = DataType::FLOAT32; // dtype = DataType::FLOAT32;
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace()); auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
auto det = NDArrayFactory::create<T>(1); auto det = NDArrayFactory::create<T>(1, context);
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
dim3 launchDims(256, 256, 1024); dim3 launchDims(256, 256, 1024);
@ -677,7 +677,7 @@ namespace helpers {
dtype = DataType::FLOAT32; dtype = DataType::FLOAT32;
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace()); auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace());
auto det = NDArrayFactory::create<T>(1); auto det = NDArrayFactory::create<T>(1, context);
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
dim3 launchDims(256, 256, 1024); dim3 launchDims(256, 256, 1024);

View File

@ -110,7 +110,7 @@ namespace helpers {
auto resR = fullMatricies?R->ulike():matrix->ulike(); auto resR = fullMatricies?R->ulike():matrix->ulike();
std::vector<NDArray> q(M); std::vector<NDArray> q(M);
NDArray z = *matrix; NDArray z = *matrix;
NDArray e('c', {M}, DataTypeUtils::fromT<T>()); // two internal buffers and scalar for squared norm NDArray e('c', {M}, DataTypeUtils::fromT<T>(), context); // two internal buffers and scalar for squared norm
for (auto k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number for (auto k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number
e.nullify(); e.nullify();
z = matrixMinor<T>(context, z, k); // minor computing for current column with given matrix z (initally is a input matrix) z = matrixMinor<T>(context, z, k); // minor computing for current column with given matrix z (initally is a input matrix)
@ -177,4 +177,3 @@ namespace helpers {
} }
} }
} }

View File

@ -167,8 +167,8 @@ namespace sd {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
indices->syncToHost(); indices->syncToHost();
Nd4jLong numOfClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1; Nd4jLong numOfClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0); classesRangesLens.assign(0);
@ -209,8 +209,8 @@ namespace sd {
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2}); // NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
output->assign(DataTypeUtils::infOrMax<T>()); output->assign(DataTypeUtils::infOrMax<T>());
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), row, classes); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), row, classes);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());

View File

@ -158,8 +158,8 @@ namespace helpers {
static void segmentMeanFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { static void segmentMeanFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1; Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0); classesRangesLens.assign(0);
@ -198,8 +198,8 @@ namespace helpers {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2}); // NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
@ -314,8 +314,8 @@ namespace helpers {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1; auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0); classesRangesLens.assign(0);
@ -367,8 +367,8 @@ namespace helpers {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1; auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0); classesRangesLens.assign(0);

View File

@ -161,8 +161,8 @@ namespace helpers {
static void segmentMinFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { static void segmentMinFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1; Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
output->assign(DataTypeUtils::infOrMax<T>()); output->assign(DataTypeUtils::infOrMax<T>());
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0); classesRangesLens.assign(0);
@ -202,8 +202,8 @@ namespace helpers {
static void unsortedSegmentMinFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { static void unsortedSegmentMinFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2}); // NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
output->assign(DataTypeUtils::infOrMax<T>()); output->assign(DataTypeUtils::infOrMax<T>());

View File

@ -122,8 +122,8 @@ namespace helpers {
static void segmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { static void segmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1; Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
output->assign(1); output->assign(1);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0); classesRangesLens.assign(0);
@ -160,8 +160,8 @@ namespace helpers {
static void unsortedSegmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { static void unsortedSegmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2}); // NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());

View File

@ -86,8 +86,8 @@ namespace helpers {
static void unsortedSegmentSqrtNFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { static void unsortedSegmentSqrtNFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2}); // NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
@ -207,8 +207,8 @@ namespace helpers {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1; auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0); classesRangesLens.assign(0);

View File

@ -162,8 +162,8 @@ namespace helpers {
static void segmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { static void segmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1; Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0); classesRangesLens.assign(0);
@ -201,8 +201,8 @@ namespace helpers {
static void unsortedSegmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { static void unsortedSegmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2}); // NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());

View File

@ -41,7 +41,7 @@ namespace helpers {
length += array->lengthOf(); length += array->lengthOf();
pos++; pos++;
} }
NDArray arrayFull('c', {length}, sd::DataType::INT32); NDArray arrayFull('c', {length}, sd::DataType::INT32, inputList[0]->getContext());
cContext.setOutputArray(0, &arrayFull); cContext.setOutputArray(0, &arrayFull);
cContext.setIArguments(&axis, 1); cContext.setIArguments(&axis, 1);