libnd4j fixes for context sync in operation execution (#350)
Signed-off-by: Oleg <oleg.semeniv@gmail.com>master
parent
9b3576bc00
commit
bf0ddbc06c
|
@ -23,6 +23,7 @@
|
||||||
#include <array/NDArrayList.h>
|
#include <array/NDArrayList.h>
|
||||||
#include <helpers/ShapeUtils.h>
|
#include <helpers/ShapeUtils.h>
|
||||||
#include <ops/declarable/CustomOperations.h>
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include<ops/declarable/helpers/stack.h>
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
NDArrayList::NDArrayList(int height, bool expandable) {
|
NDArrayList::NDArrayList(int height, bool expandable) {
|
||||||
|
@ -144,25 +145,38 @@ namespace sd {
|
||||||
|
|
||||||
NDArray* NDArrayList::stack() {
|
NDArray* NDArrayList::stack() {
|
||||||
// FIXME: this is bad for perf, but ok as poc
|
// FIXME: this is bad for perf, but ok as poc
|
||||||
sd::ops::stack op;
|
|
||||||
std::vector<NDArray*> inputs;
|
|
||||||
std::vector<double> targs;
|
|
||||||
std::vector<Nd4jLong> iargs({0});
|
|
||||||
std::vector<bool> bargs;
|
|
||||||
int numElements = _elements.load();
|
|
||||||
|
|
||||||
|
int numElements = _elements.load();
|
||||||
|
std::vector<const NDArray*> inputs(numElements);
|
||||||
for (int e = 0; e < numElements; e++) {
|
for (int e = 0; e < numElements; e++) {
|
||||||
_chunks[e]->syncToDevice();
|
_chunks[e]->syncToDevice();
|
||||||
inputs.emplace_back(_chunks[e]);
|
inputs[e] = _chunks[e];
|
||||||
}
|
}
|
||||||
|
|
||||||
iargs.push_back(_axis);
|
auto inShapeInfo = inputs[0]->getShapeInfo();
|
||||||
|
int rank = shape::rank(inShapeInfo);
|
||||||
|
NDArray* array = nullptr;
|
||||||
|
|
||||||
auto result = op.evaluate(inputs);
|
if (shape::isEmpty(inShapeInfo)) {
|
||||||
|
switch (rank) {
|
||||||
|
case 0: {
|
||||||
|
if (numElements == 1) {
|
||||||
|
array = new NDArray(inputs[0]->ordering(), {0}, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext());
|
||||||
|
} else {
|
||||||
|
array = new NDArray('c', {(Nd4jLong) numElements, 0}, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext() ) ;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
std::vector<Nd4jLong> outShape(inShapeInfo + 1, inShapeInfo + 1 + rank);
|
||||||
|
outShape.insert(outShape.begin(), (Nd4jLong) numElements);
|
||||||
|
array = new NDArray( shape::order(inShapeInfo), outShape, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext());
|
||||||
|
}
|
||||||
|
|
||||||
auto array = new NDArray(result.at(0)->dup());
|
ops::helpers::stack(inputs[0]->getContext(), inputs, *array, 0);
|
||||||
|
|
||||||
return array;
|
return array;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<int,int>& NDArrayList::id() {
|
std::pair<int,int>& NDArrayList::id() {
|
||||||
|
|
|
@ -14,10 +14,10 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @author raver119@gmail.com
|
||||||
// modified by sgazeos@gmail.com with backprop implementation.
|
// modified by sgazeos@gmail.com with backprop implementation.
|
||||||
//
|
//
|
||||||
#include <system/op_boilerplate.h>
|
#include <system/op_boilerplate.h>
|
||||||
#if NOT_EXCLUDED(OP_floormod)
|
#if NOT_EXCLUDED(OP_floormod)
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ namespace sd {
|
||||||
auto y = INPUT_VARIABLE(1);
|
auto y = INPUT_VARIABLE(1);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
BROADCAST_CHECK_EMPTY(x,y,z);
|
BROADCAST_CHECK_EMPTY(x, y, z);
|
||||||
|
|
||||||
REQUIRE_TRUE(!y->isB(), 0, "FLOORMOD OP: you can't divide by bool array!");
|
REQUIRE_TRUE(!y->isB(), 0, "FLOORMOD OP: you can't divide by bool array!");
|
||||||
auto tZ = BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, z);
|
auto tZ = BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, z);
|
||||||
|
@ -46,15 +46,15 @@ namespace sd {
|
||||||
|
|
||||||
DECLARE_TYPES(floormod) {
|
DECLARE_TYPES(floormod) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(0, DataType::ANY)
|
->setAllowedInputTypes(0, DataType::ANY)
|
||||||
->setAllowedInputTypes(1, DataType::ANY)
|
->setAllowedInputTypes(1, DataType::ANY)
|
||||||
->setAllowedOutputTypes(0, DataType::INHERIT);
|
->setAllowedOutputTypes(0, DataType::INHERIT);
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(floormod_bp) {
|
DECLARE_TYPES(floormod_bp) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(DataType::ANY)
|
->setAllowedInputTypes(DataType::ANY)
|
||||||
->setAllowedOutputTypes({ALL_FLOATS});
|
->setAllowedOutputTypes({ ALL_FLOATS });
|
||||||
}
|
}
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(floormod_bp, 3, 2, false, 0, 0) {
|
CUSTOM_OP_IMPL(floormod_bp, 3, 2, false, 0, 0) {
|
||||||
|
@ -66,11 +66,11 @@ namespace sd {
|
||||||
auto gradY = OUTPUT_VARIABLE(1);
|
auto gradY = OUTPUT_VARIABLE(1);
|
||||||
gradX->assign(epsNext);
|
gradX->assign(epsNext);
|
||||||
|
|
||||||
sd::ops::floormod op;
|
NDArray temp(*epsNext);
|
||||||
auto tmpResult(op.evaluate({x, y}));
|
BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, &temp);
|
||||||
|
|
||||||
if (gradY->rankOf() == gradX->rankOf())
|
if (gradY->rankOf() == gradX->rankOf())
|
||||||
epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult.at(0), *gradY);
|
epsNext->applyPairwiseTransform(pairwise::Multiply, temp, *gradY);
|
||||||
else // epsNext is greater than gradY
|
else // epsNext is greater than gradY
|
||||||
{
|
{
|
||||||
std::vector<Nd4jLong> dims(epsNext->rankOf() * 2);
|
std::vector<Nd4jLong> dims(epsNext->rankOf() * 2);
|
||||||
|
@ -78,7 +78,7 @@ namespace sd {
|
||||||
for (Nd4jLong d = 0; d < gap; d++) {
|
for (Nd4jLong d = 0; d < gap; d++) {
|
||||||
dims[d * 2 + 1] = 1;
|
dims[d * 2 + 1] = 1;
|
||||||
}
|
}
|
||||||
auto tempIn((*tmpResult.at(0))(dims));
|
auto tempIn((temp)(dims));
|
||||||
(*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY);
|
(*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -92,8 +92,8 @@ namespace sd {
|
||||||
// eps always has shape of x
|
// eps always has shape of x
|
||||||
// grad always has shape of y
|
// grad always has shape of y
|
||||||
|
|
||||||
Nd4jLong *shapeE;
|
Nd4jLong* shapeE;
|
||||||
Nd4jLong *shapeG;
|
Nd4jLong* shapeG;
|
||||||
|
|
||||||
COPY_SHAPE(x, shapeE);
|
COPY_SHAPE(x, shapeE);
|
||||||
COPY_SHAPE(y, shapeG);
|
COPY_SHAPE(y, shapeG);
|
||||||
|
|
|
@ -14,9 +14,9 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @author raver119@gmail.com
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <system/op_boilerplate.h>
|
#include <system/op_boilerplate.h>
|
||||||
#if NOT_EXCLUDED(OP_split_string)
|
#if NOT_EXCLUDED(OP_split_string)
|
||||||
|
@ -60,7 +60,7 @@ namespace sd {
|
||||||
|
|
||||||
// filling output indices
|
// filling output indices
|
||||||
for (uint64_t f = 0; f < cnt; f++) {
|
for (uint64_t f = 0; f < cnt; f++) {
|
||||||
for (auto v: icoords)
|
for (auto v : icoords)
|
||||||
indices->p(ic++, v);
|
indices->p(ic++, v);
|
||||||
|
|
||||||
// last index
|
// last index
|
||||||
|
@ -75,12 +75,12 @@ namespace sd {
|
||||||
for (auto e = 0L; e < input->lengthOf(); e++) {
|
for (auto e = 0L; e < input->lengthOf(); e++) {
|
||||||
auto split = StringUtils::split(input->e<std::string>(e), d);
|
auto split = StringUtils::split(input->e<std::string>(e), d);
|
||||||
|
|
||||||
for (const auto &s:split)
|
for (const auto& s : split)
|
||||||
strings.emplace_back(s);
|
strings.emplace_back(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
// now once we have all strings in single vector time to fill
|
// now once we have all strings in single vector time to fill
|
||||||
auto tmp = NDArrayFactory::string({(Nd4jLong) strings.size()}, strings);
|
auto tmp = NDArrayFactory::string({ (Nd4jLong)strings.size() }, strings, input->dataType(), block.launchContext());
|
||||||
auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size());
|
auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size());
|
||||||
|
|
||||||
// for CUDA mostly
|
// for CUDA mostly
|
||||||
|
@ -129,9 +129,9 @@ namespace sd {
|
||||||
|
|
||||||
DECLARE_TYPES(compat_string_split) {
|
DECLARE_TYPES(compat_string_split) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes({ALL_STRINGS})
|
->setAllowedInputTypes({ ALL_STRINGS })
|
||||||
->setAllowedOutputTypes(0, {ALL_INDICES})
|
->setAllowedOutputTypes(0, { ALL_INDICES })
|
||||||
->setAllowedOutputTypes(1, {ALL_STRINGS});
|
->setAllowedOutputTypes(1, { ALL_STRINGS });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -68,6 +68,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = *weights * E.lengthOf();
|
sum = *weights * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
@ -201,6 +202,7 @@ namespace sd {
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
|
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = (*weights) * E.lengthOf();
|
sum = (*weights) * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
|
|
@ -73,6 +73,7 @@ CUSTOM_OP_IMPL(huber_loss, 3, 1, false, 1, 1) {
|
||||||
}
|
}
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = *weights * E.lengthOf();
|
sum = *weights * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
@ -216,6 +217,7 @@ DECLARE_SHAPE_FN(huber_loss) {
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
|
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = (*weights) * E.lengthOf();
|
sum = (*weights) * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
|
|
@ -70,6 +70,7 @@ CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) {
|
||||||
}
|
}
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = *weights * E.lengthOf();
|
sum = *weights * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
@ -206,6 +207,7 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) {
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
|
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = (*weights) * E.lengthOf();
|
sum = (*weights) * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
|
|
@ -74,6 +74,7 @@ namespace ops {
|
||||||
}
|
}
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = *weights * E.lengthOf();
|
sum = *weights * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
@ -209,6 +210,7 @@ namespace ops {
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
|
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = (*weights) * E.lengthOf();
|
sum = (*weights) * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
|
|
@ -143,6 +143,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = (*weights) * E.lengthOf();
|
sum = (*weights) * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
@ -282,6 +283,7 @@ namespace sd {
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
|
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = (*weights) * E.lengthOf();
|
sum = (*weights) * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
|
|
@ -67,6 +67,7 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) {
|
||||||
}
|
}
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = (*weights) * E.lengthOf();
|
sum = (*weights) * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
@ -200,6 +201,7 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) {
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
|
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = (*weights) * E.lengthOf();
|
sum = (*weights) * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
|
|
@ -78,6 +78,7 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) {
|
||||||
}
|
}
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = (*weights) * E.lengthOf();
|
sum = (*weights) * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
@ -219,6 +220,7 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
|
||||||
}
|
}
|
||||||
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array
|
||||||
NDArray sum;
|
NDArray sum;
|
||||||
|
sum.setContext(block.launchContext());
|
||||||
if (weights->isScalar())
|
if (weights->isScalar())
|
||||||
sum = (*weights) * E.lengthOf();
|
sum = (*weights) * E.lengthOf();
|
||||||
else
|
else
|
||||||
|
|
|
@ -14,9 +14,9 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author sgazeos@gmail.com
|
// @author sgazeos@gmail.com
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/generic/helpers/BroadcastHelper.h>
|
#include <ops/declarable/generic/helpers/BroadcastHelper.h>
|
||||||
#include <ops/declarable/headers/parity_ops.h>
|
#include <ops/declarable/headers/parity_ops.h>
|
||||||
|
@ -29,12 +29,12 @@ namespace sd {
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto x = INPUT_VARIABLE(0);
|
||||||
auto y = INPUT_VARIABLE(1);
|
auto y = INPUT_VARIABLE(1);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
auto z0 = NDArrayFactory::create<bool>(x->ordering(), x->getShapeAsVector());
|
auto z0 = NDArrayFactory::create<bool>(x->ordering(), x->getShapeAsVector(), block.launchContext());
|
||||||
BROADCAST_CHECK_EMPTY(x, y, (&z0));
|
BROADCAST_CHECK_EMPTY(x, y, (&z0));
|
||||||
|
|
||||||
auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0);
|
auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0);
|
||||||
bitcast res;
|
bitcast res;
|
||||||
auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, {}, false);
|
auto status = res.execute({ tZ }, { z }, {}, { DataType::UINT8 }, {}, {}, false);
|
||||||
if (tZ != &z0) {
|
if (tZ != &z0) {
|
||||||
delete tZ;
|
delete tZ;
|
||||||
}
|
}
|
||||||
|
@ -44,9 +44,9 @@ namespace sd {
|
||||||
|
|
||||||
DECLARE_TYPES(compare_and_bitpack) {
|
DECLARE_TYPES(compare_and_bitpack) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(0, DataType::ANY)
|
->setAllowedInputTypes(0, DataType::ANY)
|
||||||
->setAllowedInputTypes(1, DataType::ANY)
|
->setAllowedInputTypes(1, DataType::ANY)
|
||||||
->setAllowedOutputTypes(0, DataType::UINT8);
|
->setAllowedOutputTypes(0, DataType::UINT8);
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(compare_and_bitpack) {
|
DECLARE_SHAPE_FN(compare_and_bitpack) {
|
||||||
|
|
|
@ -14,9 +14,9 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// Created by george@skymind.io on 26.01.2018.
|
// Created by george@skymind.io on 26.01.2018.
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <system/op_boilerplate.h>
|
#include <system/op_boilerplate.h>
|
||||||
#if NOT_EXCLUDED(OP_normalize_moments)
|
#if NOT_EXCLUDED(OP_normalize_moments)
|
||||||
|
@ -34,7 +34,7 @@ namespace sd {
|
||||||
auto resVariances = OUTPUT_VARIABLE(1);
|
auto resVariances = OUTPUT_VARIABLE(1);
|
||||||
|
|
||||||
// FIXME: double?
|
// FIXME: double?
|
||||||
NDArray shift = NDArrayFactory::create<double>(0.);
|
NDArray shift = NDArrayFactory::create<double>(0., block.launchContext());
|
||||||
|
|
||||||
if (block.getTArguments()->size() > 0) {
|
if (block.getTArguments()->size() > 0) {
|
||||||
shift.assign(T_ARG(0));
|
shift.assign(T_ARG(0));
|
||||||
|
@ -47,7 +47,7 @@ namespace sd {
|
||||||
|
|
||||||
squareMeans.applyTransform(transform::Square, squareMeans, nullptr);
|
squareMeans.applyTransform(transform::Square, squareMeans, nullptr);
|
||||||
variances->applyScalarArr(scalar::Divide, *counts, tempVariances);
|
variances->applyScalarArr(scalar::Divide, *counts, tempVariances);
|
||||||
// tempVariances.printIndexedBuffer("varianced divided by count");
|
// tempVariances.printIndexedBuffer("varianced divided by count");
|
||||||
tempVariances.applyPairwiseTransform(pairwise::Subtract, squareMeans, *resVariances);
|
tempVariances.applyPairwiseTransform(pairwise::Subtract, squareMeans, *resVariances);
|
||||||
|
|
||||||
if (shift.e<double>(0) != 0) {
|
if (shift.e<double>(0) != 0) {
|
||||||
|
@ -75,8 +75,8 @@ namespace sd {
|
||||||
|
|
||||||
DECLARE_TYPES(normalize_moments) {
|
DECLARE_TYPES(normalize_moments) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(sd::DataType::ANY)
|
->setAllowedInputTypes(sd::DataType::ANY)
|
||||||
->setAllowedOutputTypes({ALL_FLOATS});
|
->setAllowedOutputTypes({ ALL_FLOATS });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -49,8 +49,8 @@ namespace sd {
|
||||||
bool disposable = false;
|
bool disposable = false;
|
||||||
|
|
||||||
if (min == nullptr && max == nullptr && block.numT() >= 2) {
|
if (min == nullptr && max == nullptr && block.numT() >= 2) {
|
||||||
min = NDArrayFactory::create_(dtype);
|
min = NDArrayFactory::create_(dtype, block.launchContext());
|
||||||
max = NDArrayFactory::create_(dtype);
|
max = NDArrayFactory::create_(dtype, block.launchContext());
|
||||||
min->p(0, T_ARG(0));
|
min->p(0, T_ARG(0));
|
||||||
max->p(0, T_ARG(1));
|
max->p(0, T_ARG(1));
|
||||||
disposable = true;
|
disposable = true;
|
||||||
|
|
|
@ -14,9 +14,9 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author George A. Shulinok <sgazeos@gmail.com>, created on 4/18/2019.
|
// @author George A. Shulinok <sgazeos@gmail.com>, created on 4/18/2019.
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <system/op_boilerplate.h>
|
#include <system/op_boilerplate.h>
|
||||||
#if NOT_EXCLUDED(OP_barnes_symmetrized)
|
#if NOT_EXCLUDED(OP_barnes_symmetrized)
|
||||||
|
@ -25,20 +25,20 @@
|
||||||
#include <ops/declarable/helpers/BarnesHutTsne.h>
|
#include <ops/declarable/helpers/BarnesHutTsne.h>
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
NDArray* rowCountsPtr = nullptr;
|
NDArray* rowCountsPtr = nullptr;
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(barnes_symmetrized, 3, 3, false, 0, -1) {
|
CUSTOM_OP_IMPL(barnes_symmetrized, 3, 3, false, 0, -1) {
|
||||||
auto rowP = INPUT_VARIABLE(0);
|
auto rowP = INPUT_VARIABLE(0);
|
||||||
auto colP = INPUT_VARIABLE(1);
|
auto colP = INPUT_VARIABLE(1);
|
||||||
auto valP = INPUT_VARIABLE(2);
|
auto valP = INPUT_VARIABLE(2);
|
||||||
auto N = rowP->lengthOf() - 1;
|
auto N = rowP->lengthOf() - 1;
|
||||||
auto outputRows = OUTPUT_VARIABLE(0);
|
auto outputRows = OUTPUT_VARIABLE(0);
|
||||||
auto outputCols = OUTPUT_VARIABLE(1);
|
auto outputCols = OUTPUT_VARIABLE(1);
|
||||||
auto outputVals = OUTPUT_VARIABLE(2);
|
auto outputVals = OUTPUT_VARIABLE(2);
|
||||||
|
|
||||||
if (block.getIArguments()->size() > 0)
|
if (block.getIArguments()->size() > 0)
|
||||||
N = INT_ARG(0);
|
N = INT_ARG(0);
|
||||||
|
|
||||||
if (rowCountsPtr) {
|
if (rowCountsPtr) {
|
||||||
helpers::barnes_symmetrize(rowP, colP, valP, N, outputRows, outputCols, outputVals, rowCountsPtr);
|
helpers::barnes_symmetrize(rowP, colP, valP, N, outputRows, outputCols, outputVals, rowCountsPtr);
|
||||||
|
@ -46,33 +46,33 @@ namespace ops {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
return Status::THROW("barnes_symmetrized: Cannot loop due wrong input data.");
|
return Status::THROW("barnes_symmetrized: Cannot loop due wrong input data.");
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(barnes_symmetrized) {
|
DECLARE_TYPES(barnes_symmetrized) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(0, {DataType::INT32})
|
->setAllowedInputTypes(0, { DataType::INT32 })
|
||||||
->setAllowedInputTypes(1, {DataType::INT32})
|
->setAllowedInputTypes(1, { DataType::INT32 })
|
||||||
->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS})
|
->setAllowedInputTypes(2, { ALL_INTS, ALL_FLOATS })
|
||||||
->setAllowedOutputTypes(1, {DataType::INT32})
|
->setAllowedOutputTypes(1, { DataType::INT32 })
|
||||||
->setAllowedOutputTypes(1, {DataType::INT32})
|
->setAllowedOutputTypes(1, { DataType::INT32 })
|
||||||
->setAllowedOutputTypes(2, {ALL_INTS, ALL_FLOATS})
|
->setAllowedOutputTypes(2, { ALL_INTS, ALL_FLOATS })
|
||||||
->setSameMode(false);
|
->setSameMode(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(barnes_symmetrized) {
|
DECLARE_SHAPE_FN(barnes_symmetrized) {
|
||||||
auto valPShapeInfo = inputShape->at(2);
|
auto valPShapeInfo = inputShape->at(2);
|
||||||
Nd4jLong* outShapeInfo;
|
Nd4jLong* outShapeInfo;
|
||||||
auto rowP = INPUT_VARIABLE(0);
|
auto rowP = INPUT_VARIABLE(0);
|
||||||
auto colP = INPUT_VARIABLE(1);
|
auto colP = INPUT_VARIABLE(1);
|
||||||
auto N = rowP->lengthOf() - 1;
|
auto N = rowP->lengthOf() - 1;
|
||||||
if (block.getIArguments()->size() > 0)
|
if (block.getIArguments()->size() > 0)
|
||||||
N = INT_ARG(0);
|
N = INT_ARG(0);
|
||||||
auto dataType = rowP->dataType(); //ArrayOptions::dataType(inputShape->at(0));
|
auto dataType = rowP->dataType(); //ArrayOptions::dataType(inputShape->at(0));
|
||||||
NDArray* rowCounts = NDArrayFactory::create_<int>('c', {N}); //rowP->dup();
|
NDArray* rowCounts = NDArrayFactory::create_<int>('c', { N }, block.launchContext()); //rowP->dup();
|
||||||
//srowCounts->assign(0);
|
//srowCounts->assign(0);
|
||||||
Nd4jLong len = helpers::barnes_row_count(rowP, colP, N, *rowCounts);
|
Nd4jLong len = helpers::barnes_row_count(rowP, colP, N, *rowCounts);
|
||||||
rowCounts->syncToHost();
|
rowCounts->syncToHost();
|
||||||
// rowCounts->printBuffer("Row Counts");
|
// rowCounts->printBuffer("Row Counts");
|
||||||
if (len <= 0) throw std::runtime_error("barnes_symmetrized: Cannot allocate shape due non-positive len.");
|
if (len <= 0) throw std::runtime_error("barnes_symmetrized: Cannot allocate shape due non-positive len.");
|
||||||
rowCountsPtr = rowCounts;
|
rowCountsPtr = rowCounts;
|
||||||
//ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(2), Nd4jLong);
|
//ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(2), Nd4jLong);
|
||||||
|
@ -80,13 +80,13 @@ namespace ops {
|
||||||
// outShapeInfo[2] = len;
|
// outShapeInfo[2] = len;
|
||||||
// ShapeUtils::updateStridesAndType(outShapeInfo, ArrayOptions::dataType(valPShapeInfo), 'c');
|
// ShapeUtils::updateStridesAndType(outShapeInfo, ArrayOptions::dataType(valPShapeInfo), 'c');
|
||||||
//outShapeInfo = ShapeBuilders::createVectorShapeInfo(ArrayOptions::dataType(valPShapeInfo), len, block.workspace());
|
//outShapeInfo = ShapeBuilders::createVectorShapeInfo(ArrayOptions::dataType(valPShapeInfo), len, block.workspace());
|
||||||
outShapeInfo = sd::ShapeBuilders::createShapeInfo(ArrayOptions::dataType(valPShapeInfo), 'c', {1, len}, block.getWorkspace());
|
outShapeInfo = sd::ShapeBuilders::createShapeInfo(ArrayOptions::dataType(valPShapeInfo), 'c', { 1, len }, block.getWorkspace());
|
||||||
auto outColsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', {1, len}, block.getWorkspace());
|
auto outColsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', { 1, len }, block.getWorkspace());
|
||||||
auto outRowsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', {1, N + 1}, block.getWorkspace());
|
auto outRowsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', { 1, N + 1 }, block.getWorkspace());
|
||||||
return SHAPELIST(CONSTANT(outRowsShapeInfo), CONSTANT(outColsShapeInfo), CONSTANT(outShapeInfo));
|
return SHAPELIST(CONSTANT(outRowsShapeInfo), CONSTANT(outColsShapeInfo), CONSTANT(outShapeInfo));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
|
@ -142,7 +142,7 @@ namespace helpers {
|
||||||
const int rowNum = input->rows();
|
const int rowNum = input->rows();
|
||||||
const int columnNum = input->columns();
|
const int columnNum = input->columns();
|
||||||
|
|
||||||
NDArray determinant = NDArrayFactory::create<T>(1.f);
|
NDArray determinant = NDArrayFactory::create<T>(1.f, context);
|
||||||
NDArray compoundMatrix = *input; // copy
|
NDArray compoundMatrix = *input; // copy
|
||||||
NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides
|
NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides
|
||||||
permutationMatrix.setIdentity();
|
permutationMatrix.setIdentity();
|
||||||
|
|
|
@ -39,7 +39,7 @@ namespace helpers {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
NDArray vmul(NDArray const& v, int n)
|
NDArray vmul(NDArray const& v, int n)
|
||||||
{
|
{
|
||||||
NDArray res('c', {n,n}, v.dataType()); // x = matrix_new(n, n);
|
NDArray res('c', {n,n}, v.dataType(), v.getContext()); // x = matrix_new(n, n);
|
||||||
T const* vBuf = v.getDataBuffer()->primaryAsT<T>();
|
T const* vBuf = v.getDataBuffer()->primaryAsT<T>();
|
||||||
T* resBuf = res.dataBuffer()->primaryAsT<T>();
|
T* resBuf = res.dataBuffer()->primaryAsT<T>();
|
||||||
auto interloop = PRAGMA_THREADS_FOR_2D {
|
auto interloop = PRAGMA_THREADS_FOR_2D {
|
||||||
|
@ -61,7 +61,7 @@ namespace helpers {
|
||||||
std::vector<NDArray> q(M);
|
std::vector<NDArray> q(M);
|
||||||
|
|
||||||
NDArray z = *matrix;
|
NDArray z = *matrix;
|
||||||
NDArray e('c', {M}, DataTypeUtils::fromT<T>()); // two internal buffers and scalar for squared norm
|
NDArray e('c', {M}, DataTypeUtils::fromT<T>(), Q->getContext()); // two internal buffers and scalar for squared norm
|
||||||
|
|
||||||
for (Nd4jLong k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number
|
for (Nd4jLong k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number
|
||||||
e.nullify();
|
e.nullify();
|
||||||
|
|
|
@ -69,9 +69,9 @@ namespace helpers {
|
||||||
auto trial = (*input)(e, dimsToExclude);
|
auto trial = (*input)(e, dimsToExclude);
|
||||||
|
|
||||||
// fill up the first k elements
|
// fill up the first k elements
|
||||||
NDArray topValues = NDArrayFactory::create<T>('c', {k});
|
NDArray topValues = NDArrayFactory::create<T>('c', {k}, input->getContext());
|
||||||
NDArray sortedVals = NDArrayFactory::create<T>('c', {k});
|
NDArray sortedVals = NDArrayFactory::create<T>('c', {k}, input->getContext());
|
||||||
NDArray topIndices = NDArrayFactory::create<Nd4jLong>('c', {k});
|
NDArray topIndices = NDArrayFactory::create<Nd4jLong>('c', {k}, input->getContext());
|
||||||
for (uint pos = 0; pos < k; ++pos) {
|
for (uint pos = 0; pos < k; ++pos) {
|
||||||
topIndices.t<Nd4jLong>(pos) = pos;
|
topIndices.t<Nd4jLong>(pos) = pos;
|
||||||
topValues.t<T>(pos) = trial.t<T>(pos);
|
topValues.t<T>(pos) = trial.t<T>(pos);
|
||||||
|
@ -144,7 +144,7 @@ namespace helpers {
|
||||||
for (int i = 0; i < input->rankOf() - 1; i++)
|
for (int i = 0; i < input->rankOf() - 1; i++)
|
||||||
shapeI[i] = input->sizeAt(i);
|
shapeI[i] = input->sizeAt(i);
|
||||||
shapeI[input->rankOf() - 1] = k;
|
shapeI[input->rankOf() - 1] = k;
|
||||||
std::unique_ptr<NDArray> indices(NDArrayFactory::create_<Nd4jLong>(input->ordering(), shapeI));
|
std::unique_ptr<NDArray> indices(NDArrayFactory::create_<Nd4jLong>(input->ordering(), shapeI, context));
|
||||||
NDArray* values = nullptr;
|
NDArray* values = nullptr;
|
||||||
int status = topKFunctor(context, input, values, indices.get(), k, true);
|
int status = topKFunctor(context, input, values, indices.get(), k, true);
|
||||||
result->assign(0);
|
result->assign(0);
|
||||||
|
|
|
@ -112,7 +112,7 @@ namespace sd {
|
||||||
int numThreads = 256;
|
int numThreads = 256;
|
||||||
int numBlocks = sd::math::nd4j_max<int>(256, sd::math::nd4j_min<int>(1, shape::length(xShapeInfo) / numThreads));
|
int numBlocks = sd::math::nd4j_max<int>(256, sd::math::nd4j_min<int>(1, shape::length(xShapeInfo) / numThreads));
|
||||||
int workspaceSize = numBlocks * numBins;
|
int workspaceSize = numBlocks * numBins;
|
||||||
auto tmp = NDArrayFactory::create<Z>('c', {workspaceSize});
|
auto tmp = NDArrayFactory::create<Z>('c', {workspaceSize}, context);
|
||||||
|
|
||||||
histogramKernel<X, Z><<<numBlocks, numThreads, 32768, *context->getCudaStream()>>>(xBuffer, dxShapeInfo, zBuffer, zShapeInfo, tmp.getSpecialBuffer(), context->getReductionPointer(), numBins, reinterpret_cast<X*>(min_val), reinterpret_cast<X*>(max_val));
|
histogramKernel<X, Z><<<numBlocks, numThreads, 32768, *context->getCudaStream()>>>(xBuffer, dxShapeInfo, zBuffer, zShapeInfo, tmp.getSpecialBuffer(), context->getReductionPointer(), numBins, reinterpret_cast<X*>(min_val), reinterpret_cast<X*>(max_val));
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
typedef NDArray ColorTable_t;
|
typedef NDArray ColorTable_t;
|
||||||
static NDArray DefaultColorTable(int depth) {
|
static NDArray DefaultColorTable(int depth, sd::LaunchContext* context) {
|
||||||
//std::vector<std::vector<float>> colorTable;
|
//std::vector<std::vector<float>> colorTable;
|
||||||
const Nd4jLong kDefaultTableLength = 10;
|
const Nd4jLong kDefaultTableLength = 10;
|
||||||
const Nd4jLong kDefaultChannelLength = 4;
|
const Nd4jLong kDefaultChannelLength = 4;
|
||||||
|
@ -40,7 +40,7 @@ namespace helpers {
|
||||||
0, 0, 0.5, 1, // 7: navy blue
|
0, 0, 0.5, 1, // 7: navy blue
|
||||||
0, 1, 1, 1, // 8: aqua
|
0, 1, 1, 1, // 8: aqua
|
||||||
1, 0, 1, 1 // 9: fuchsia
|
1, 0, 1, 1 // 9: fuchsia
|
||||||
}, DataType::FLOAT32);
|
}, DataType::FLOAT32, context);
|
||||||
|
|
||||||
if (depth == 1) {
|
if (depth == 1) {
|
||||||
colorTable.assign(1.f); // all to white when black and white colors
|
colorTable.assign(1.f); // all to white when black and white colors
|
||||||
|
@ -144,7 +144,7 @@ namespace helpers {
|
||||||
auto channels = images->sizeAt(3);
|
auto channels = images->sizeAt(3);
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
auto boxSize = boxes->sizeAt(1);
|
auto boxSize = boxes->sizeAt(1);
|
||||||
NDArray colorsTable = DefaultColorTable(channels);
|
NDArray colorsTable = DefaultColorTable(channels, context);
|
||||||
if ((colors != nullptr && colors->lengthOf() > 0)) {
|
if ((colors != nullptr && colors->lengthOf() > 0)) {
|
||||||
colorsTable = *colors;
|
colorsTable = *colors;
|
||||||
}
|
}
|
||||||
|
|
|
@ -188,7 +188,7 @@ namespace helpers {
|
||||||
static void nonMaxSuppressionV2_(sd::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) {
|
static void nonMaxSuppressionV2_(sd::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
NDArray::prepareSpecialUse({output}, {boxes, scales});
|
NDArray::prepareSpecialUse({output}, {boxes, scales});
|
||||||
std::unique_ptr<NDArray> indices(NDArrayFactory::create_<I>('c', {scales->lengthOf()})); // - 1, scales->lengthOf()); //, scales->getContext());
|
std::unique_ptr<NDArray> indices(NDArrayFactory::create_<I>('c', {scales->lengthOf()}, context)); // - 1, scales->lengthOf()); //, scales->getContext());
|
||||||
|
|
||||||
NDArray scores(*scales);
|
NDArray scores(*scales);
|
||||||
Nd4jPointer extras[2] = {nullptr, stream};
|
Nd4jPointer extras[2] = {nullptr, stream};
|
||||||
|
@ -198,7 +198,7 @@ namespace helpers {
|
||||||
indices->tickWriteDevice();
|
indices->tickWriteDevice();
|
||||||
sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true);
|
sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true);
|
||||||
indices->tickWriteDevice();
|
indices->tickWriteDevice();
|
||||||
NDArray selectedIndices = NDArrayFactory::create<I>('c', {output->lengthOf()});
|
NDArray selectedIndices = NDArrayFactory::create<I>('c', {output->lengthOf()}, context);
|
||||||
int numSelected = 0;
|
int numSelected = 0;
|
||||||
int numBoxes = boxes->sizeAt(0);
|
int numBoxes = boxes->sizeAt(0);
|
||||||
auto boxesBuf = reinterpret_cast<T*>(boxes->specialBuffer());
|
auto boxesBuf = reinterpret_cast<T*>(boxes->specialBuffer());
|
||||||
|
@ -347,8 +347,8 @@ namespace helpers {
|
||||||
scores->syncToDevice();
|
scores->syncToDevice();
|
||||||
}
|
}
|
||||||
|
|
||||||
NDArray indices = NDArrayFactory::create<I>('c', {scores->lengthOf()}); // - 1, scales->lengthOf()); //, scales->getContext());
|
NDArray indices = NDArrayFactory::create<I>('c', {scores->lengthOf()}, context); // - 1, scales->lengthOf()); //, scales->getContext());
|
||||||
NDArray startPositions = NDArrayFactory::create<I>('c', {scores->lengthOf()});
|
NDArray startPositions = NDArrayFactory::create<I>('c', {scores->lengthOf()}, context);
|
||||||
NDArray selectedScores(*scores);
|
NDArray selectedScores(*scores);
|
||||||
Nd4jPointer extras[2] = {nullptr, stream};
|
Nd4jPointer extras[2] = {nullptr, stream};
|
||||||
auto indexBuf = indices.dataBuffer()->specialAsT<I>();///reinterpret_cast<I*>(indices->specialBuffer());
|
auto indexBuf = indices.dataBuffer()->specialAsT<I>();///reinterpret_cast<I*>(indices->specialBuffer());
|
||||||
|
|
|
@ -598,7 +598,7 @@ namespace helpers {
|
||||||
static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) {
|
static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) {
|
||||||
auto n = input->sizeAt(-1);
|
auto n = input->sizeAt(-1);
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
NDArray iota('c', {n}, permutationVectors->dataType());// = NDArrayFactory::create(); // <int>('c', {n});
|
NDArray iota('c', {n}, permutationVectors->dataType(), context);// = NDArrayFactory::create(); // <int>('c', {n});
|
||||||
iota.linspace(0); iota.syncToDevice();
|
iota.linspace(0); iota.syncToDevice();
|
||||||
|
|
||||||
output->assign(input); // fill up output tensor with zeros
|
output->assign(input); // fill up output tensor with zeros
|
||||||
|
@ -631,7 +631,7 @@ namespace helpers {
|
||||||
// if (dtype != DataType::DOUBLE)
|
// if (dtype != DataType::DOUBLE)
|
||||||
// dtype = DataType::FLOAT32;
|
// dtype = DataType::FLOAT32;
|
||||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
|
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
|
||||||
auto det = NDArrayFactory::create<T>(1);
|
auto det = NDArrayFactory::create<T>(1, context);
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
NDArray::prepareSpecialUse({output}, {input});
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
dim3 launchDims(256, 256, 1024);
|
dim3 launchDims(256, 256, 1024);
|
||||||
|
@ -677,7 +677,7 @@ namespace helpers {
|
||||||
dtype = DataType::FLOAT32;
|
dtype = DataType::FLOAT32;
|
||||||
|
|
||||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace());
|
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace());
|
||||||
auto det = NDArrayFactory::create<T>(1);
|
auto det = NDArrayFactory::create<T>(1, context);
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
NDArray::prepareSpecialUse({output}, {input});
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
dim3 launchDims(256, 256, 1024);
|
dim3 launchDims(256, 256, 1024);
|
||||||
|
|
|
@ -110,7 +110,7 @@ namespace helpers {
|
||||||
auto resR = fullMatricies?R->ulike():matrix->ulike();
|
auto resR = fullMatricies?R->ulike():matrix->ulike();
|
||||||
std::vector<NDArray> q(M);
|
std::vector<NDArray> q(M);
|
||||||
NDArray z = *matrix;
|
NDArray z = *matrix;
|
||||||
NDArray e('c', {M}, DataTypeUtils::fromT<T>()); // two internal buffers and scalar for squared norm
|
NDArray e('c', {M}, DataTypeUtils::fromT<T>(), context); // two internal buffers and scalar for squared norm
|
||||||
for (auto k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number
|
for (auto k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number
|
||||||
e.nullify();
|
e.nullify();
|
||||||
z = matrixMinor<T>(context, z, k); // minor computing for current column with given matrix z (initally is a input matrix)
|
z = matrixMinor<T>(context, z, k); // minor computing for current column with given matrix z (initally is a input matrix)
|
||||||
|
@ -177,4 +177,3 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -167,8 +167,8 @@ namespace sd {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
indices->syncToHost();
|
indices->syncToHost();
|
||||||
Nd4jLong numOfClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
|
Nd4jLong numOfClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
|
||||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
|
||||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
|
||||||
|
|
||||||
classesRangesBegs.assign(indices->lengthOf());
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
classesRangesLens.assign(0);
|
classesRangesLens.assign(0);
|
||||||
|
@ -209,8 +209,8 @@ namespace sd {
|
||||||
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
||||||
output->assign(DataTypeUtils::infOrMax<T>());
|
output->assign(DataTypeUtils::infOrMax<T>());
|
||||||
|
|
||||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
|
||||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
|
||||||
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
||||||
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), row, classes);
|
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), row, classes);
|
||||||
classesRangesBegs.assign(indices->lengthOf());
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
|
|
|
@ -158,8 +158,8 @@ namespace helpers {
|
||||||
static void segmentMeanFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
|
static void segmentMeanFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
|
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
|
||||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
|
||||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
|
||||||
|
|
||||||
classesRangesBegs.assign(indices->lengthOf());
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
classesRangesLens.assign(0);
|
classesRangesLens.assign(0);
|
||||||
|
@ -198,8 +198,8 @@ namespace helpers {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
||||||
|
|
||||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
|
||||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
|
||||||
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
||||||
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
|
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
|
||||||
classesRangesBegs.assign(indices->lengthOf());
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
|
@ -314,8 +314,8 @@ namespace helpers {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
|
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
|
||||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
|
||||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
|
||||||
|
|
||||||
classesRangesBegs.assign(indices->lengthOf());
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
classesRangesLens.assign(0);
|
classesRangesLens.assign(0);
|
||||||
|
@ -367,8 +367,8 @@ namespace helpers {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
|
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
|
||||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
|
||||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
|
||||||
|
|
||||||
classesRangesBegs.assign(indices->lengthOf());
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
classesRangesLens.assign(0);
|
classesRangesLens.assign(0);
|
||||||
|
|
|
@ -161,8 +161,8 @@ namespace helpers {
|
||||||
static void segmentMinFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
|
static void segmentMinFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
|
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
|
||||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
|
||||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
|
||||||
output->assign(DataTypeUtils::infOrMax<T>());
|
output->assign(DataTypeUtils::infOrMax<T>());
|
||||||
classesRangesBegs.assign(indices->lengthOf());
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
classesRangesLens.assign(0);
|
classesRangesLens.assign(0);
|
||||||
|
@ -202,8 +202,8 @@ namespace helpers {
|
||||||
static void unsortedSegmentMinFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
static void unsortedSegmentMinFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
||||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
|
||||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
|
||||||
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
||||||
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
|
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
|
||||||
output->assign(DataTypeUtils::infOrMax<T>());
|
output->assign(DataTypeUtils::infOrMax<T>());
|
||||||
|
|
|
@ -122,8 +122,8 @@ namespace helpers {
|
||||||
static void segmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
|
static void segmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
|
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
|
||||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
|
||||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
|
||||||
output->assign(1);
|
output->assign(1);
|
||||||
classesRangesBegs.assign(indices->lengthOf());
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
classesRangesLens.assign(0);
|
classesRangesLens.assign(0);
|
||||||
|
@ -160,8 +160,8 @@ namespace helpers {
|
||||||
static void unsortedSegmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
static void unsortedSegmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
||||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
|
||||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
|
||||||
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
||||||
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
|
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
|
||||||
classesRangesBegs.assign(indices->lengthOf());
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
|
|
|
@ -86,8 +86,8 @@ namespace helpers {
|
||||||
static void unsortedSegmentSqrtNFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
static void unsortedSegmentSqrtNFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
||||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
|
||||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
|
||||||
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
||||||
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
|
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
|
||||||
classesRangesBegs.assign(indices->lengthOf());
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
|
@ -207,8 +207,8 @@ namespace helpers {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
|
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
|
||||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
|
||||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
|
||||||
|
|
||||||
classesRangesBegs.assign(indices->lengthOf());
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
classesRangesLens.assign(0);
|
classesRangesLens.assign(0);
|
||||||
|
|
|
@ -162,8 +162,8 @@ namespace helpers {
|
||||||
static void segmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
|
static void segmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
|
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
|
||||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}, context);
|
||||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}, context);
|
||||||
|
|
||||||
classesRangesBegs.assign(indices->lengthOf());
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
classesRangesLens.assign(0);
|
classesRangesLens.assign(0);
|
||||||
|
@ -201,8 +201,8 @@ namespace helpers {
|
||||||
static void unsortedSegmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
static void unsortedSegmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
||||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}, context);
|
||||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}, context);
|
||||||
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
||||||
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
|
// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes);
|
||||||
classesRangesBegs.assign(indices->lengthOf());
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
|
|
|
@ -41,7 +41,7 @@ namespace helpers {
|
||||||
length += array->lengthOf();
|
length += array->lengthOf();
|
||||||
pos++;
|
pos++;
|
||||||
}
|
}
|
||||||
NDArray arrayFull('c', {length}, sd::DataType::INT32);
|
NDArray arrayFull('c', {length}, sd::DataType::INT32, inputList[0]->getContext());
|
||||||
cContext.setOutputArray(0, &arrayFull);
|
cContext.setOutputArray(0, &arrayFull);
|
||||||
cContext.setIArguments(&axis, 1);
|
cContext.setIArguments(&axis, 1);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue