Small fixes (#43)
* minor cmake changes to make macos happy * space_to_batch/batch_to_space validation fix * - choose op tweaks - tests updated to match appleid tweaks Signed-off-by: raver119@gmail.com <raver119@gmail.com> * - get rid of bad import Signed-off-by: raver119@gmail.com <raver119@gmail.com> * - choose now uses shape function - choose test updatedmaster
parent
8b4b8977a0
commit
dde50ee570
|
@ -349,7 +349,7 @@ elseif(CPU_BLAS)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
|
if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(APPLE))
|
||||||
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -Wl,-export-dynamic")
|
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -Wl,-export-dynamic")
|
||||||
SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic")
|
SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic")
|
||||||
endif()
|
endif()
|
||||||
|
|
|
@ -27,7 +27,7 @@
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
CUSTOM_OP_IMPL(choose, -1, 2, false, -1, -1) {
|
CUSTOM_OP_IMPL(choose, -1, 2, false, -2, -1) {
|
||||||
|
|
||||||
int mode = INT_ARG(0);
|
int mode = INT_ARG(0);
|
||||||
auto result = OUTPUT_VARIABLE(0);
|
auto result = OUTPUT_VARIABLE(0);
|
||||||
|
@ -61,6 +61,8 @@ namespace nd4j {
|
||||||
DECLARE_SHAPE_FN(choose) {
|
DECLARE_SHAPE_FN(choose) {
|
||||||
Nd4jLong *shape;
|
Nd4jLong *shape;
|
||||||
int rank;
|
int rank;
|
||||||
|
int mode = INT_ARG(0);
|
||||||
|
auto numResults = NDArrayFactory::create<Nd4jLong>(0L);
|
||||||
if(block.width() > 1) {
|
if(block.width() > 1) {
|
||||||
auto first = INPUT_VARIABLE(0);
|
auto first = INPUT_VARIABLE(0);
|
||||||
auto second = INPUT_VARIABLE(1);
|
auto second = INPUT_VARIABLE(1);
|
||||||
|
@ -72,18 +74,22 @@ namespace nd4j {
|
||||||
shape = second->getShapeInfo();
|
shape = second->getShapeInfo();
|
||||||
rank = second->rankOf();
|
rank = second->rankOf();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
helpers::chooseFunctorArray(block.launchContext(), first, second, mode, nullptr, &numResults);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto first = INPUT_VARIABLE(0);
|
auto first = INPUT_VARIABLE(0);
|
||||||
shape = first->getShapeInfo();
|
shape = first->getShapeInfo();
|
||||||
rank = first->rankOf();
|
rank = first->rankOf();
|
||||||
|
double scalar = T_ARG(0);
|
||||||
|
|
||||||
|
helpers::chooseFunctorScalar(block.launchContext(), first, scalar, mode, nullptr, &numResults);
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jLong* newShape;
|
auto newShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(numResults.e<Nd4jLong>(0), ArrayOptions::dataType(inputShape->at(0)));
|
||||||
COPY_SHAPE(shape, newShape);
|
|
||||||
|
|
||||||
auto shapeScalar = ConstantShapeHelper::getInstance()->scalarShapeInfo(nd4j::DataType::INT64);
|
auto shapeScalar = ConstantShapeHelper::getInstance()->scalarShapeInfo(nd4j::DataType::INT64);
|
||||||
return SHAPELIST(CONSTANT(newShape), shapeScalar);
|
return SHAPELIST(newShape, shapeScalar);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -68,10 +68,10 @@ namespace ops {
|
||||||
auto blocks = INPUT_VARIABLE(1);
|
auto blocks = INPUT_VARIABLE(1);
|
||||||
auto crops = INPUT_VARIABLE(2);
|
auto crops = INPUT_VARIABLE(2);
|
||||||
|
|
||||||
block_dims = (int) blocks->lengthOf();
|
block_dims = (int) blocks->sizeAt(0);
|
||||||
|
|
||||||
REQUIRE_TRUE(blocks->isVector() || blocks->lengthOf() == 1, 0, "BatchToSpace: blocks supposed to be vector or scalar, but got %iD instead", blocks->rankOf());
|
REQUIRE_TRUE(blocks->isVector() || blocks->lengthOf() == 1, 0, "BatchToSpace: blocks supposed to be vector or scalar, but got %iD instead", blocks->rankOf());
|
||||||
REQUIRE_TRUE(input->rankOf() >= 1 + blocks->lengthOf() + 1, 0, "BatchToSpace: blocks length + 2 should match input rank at least");
|
REQUIRE_TRUE(input->rankOf() >= 1 + blocks->lengthOf(), 0, "BatchToSpace: blocks length + 1 should match input rank at least");
|
||||||
REQUIRE_TRUE(crops->rankOf() == 2, 0, "BatchToSpace: padding should have rank of 2, but got %i instead", crops->rankOf());
|
REQUIRE_TRUE(crops->rankOf() == 2, 0, "BatchToSpace: padding should have rank of 2, but got %i instead", crops->rankOf());
|
||||||
REQUIRE_TRUE(crops->columns() == 2 && blocks->lengthOf() == crops->rows(), 0, "BatchToSpace: padding should have M rows and 2 columns");
|
REQUIRE_TRUE(crops->columns() == 2 && blocks->lengthOf() == crops->rows(), 0, "BatchToSpace: padding should have M rows and 2 columns");
|
||||||
|
|
||||||
|
@ -198,7 +198,7 @@ namespace ops {
|
||||||
auto blocks = INPUT_VARIABLE(1);
|
auto blocks = INPUT_VARIABLE(1);
|
||||||
auto crops = INPUT_VARIABLE(2);
|
auto crops = INPUT_VARIABLE(2);
|
||||||
|
|
||||||
block_dims = (int) blocks->lengthOf();
|
block_dims = (int) blocks->sizeAt(0);
|
||||||
|
|
||||||
block_shape = blocks->template asVectorT<int>();
|
block_shape = blocks->template asVectorT<int>();
|
||||||
crops_shape = crops->template asVectorT<int>();
|
crops_shape = crops->template asVectorT<int>();
|
||||||
|
|
|
@ -75,7 +75,7 @@ namespace ops {
|
||||||
block_shape.resize(block_dims);
|
block_shape.resize(block_dims);
|
||||||
padding_shape.resize(M*2);
|
padding_shape.resize(M*2);
|
||||||
|
|
||||||
REQUIRE_TRUE(input->rankOf() >= 1 + M + 1, 0, "SpaceToBatch: blocks length + 2 should match input rank at least");
|
REQUIRE_TRUE(input->rankOf() >= 1 + M, 0, "SpaceToBatch: blocks length + 1 should match input rank at least");
|
||||||
|
|
||||||
int e = 0;
|
int e = 0;
|
||||||
for (; e < block_dims; e++)
|
for (; e < block_dims; e++)
|
||||||
|
|
|
@ -120,7 +120,7 @@ namespace nd4j {
|
||||||
* @tparam T
|
* @tparam T
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_choose)
|
#if NOT_EXCLUDED(OP_choose)
|
||||||
DECLARE_CUSTOM_OP(choose, -1, 1, false, -1, -1);
|
DECLARE_CUSTOM_OP(choose, -1, 1, false, -2, -1);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -46,6 +46,7 @@ namespace helpers {
|
||||||
for (Nd4jLong i = 0; i < arg->lengthOf(); i++) {
|
for (Nd4jLong i = 0; i < arg->lengthOf(); i++) {
|
||||||
T result2 = processElementCondition(mode, arg->e<T>(i), comp->e<T>(0));
|
T result2 = processElementCondition(mode, arg->e<T>(i), comp->e<T>(0));
|
||||||
if(result2 > 0) {
|
if(result2 > 0) {
|
||||||
|
if (output != nullptr)
|
||||||
output->p(numResults, arg->e<T>(i));
|
output->p(numResults, arg->e<T>(i));
|
||||||
numResults++;
|
numResults++;
|
||||||
}
|
}
|
||||||
|
@ -56,8 +57,9 @@ namespace helpers {
|
||||||
//for comparison
|
//for comparison
|
||||||
nd4j::NDArray arg1 = *arg;
|
nd4j::NDArray arg1 = *arg;
|
||||||
for (Nd4jLong i = 0; i < arg->lengthOf(); i++) {
|
for (Nd4jLong i = 0; i < arg->lengthOf(); i++) {
|
||||||
T result2 = processElementCondition(mode, arg->e<T>(i), compScalar.e<T>(0));
|
T result2 = processElementCondition(mode, arg->e<T>(i), comp->e<T>(i));
|
||||||
if(result2 > 0) {
|
if(result2 > 0) {
|
||||||
|
if (output != nullptr)
|
||||||
output->p(numResults, arg->e<T>(i));
|
output->p(numResults, arg->e<T>(i));
|
||||||
numResults++;
|
numResults++;
|
||||||
}
|
}
|
||||||
|
@ -72,6 +74,7 @@ namespace helpers {
|
||||||
for (Nd4jLong i = 0; i < arg->lengthOf(); i++) {
|
for (Nd4jLong i = 0; i < arg->lengthOf(); i++) {
|
||||||
T result2 = processElementCondition(mode, arg->e<T>(i), compScalar.e<T>(0));
|
T result2 = processElementCondition(mode, arg->e<T>(i), compScalar.e<T>(0));
|
||||||
if(result2 > 0) {
|
if(result2 > 0) {
|
||||||
|
if (output != nullptr)
|
||||||
output->p(numResults, arg->e<T>(i));
|
output->p(numResults, arg->e<T>(i));
|
||||||
numResults++;
|
numResults++;
|
||||||
}
|
}
|
||||||
|
@ -82,7 +85,6 @@ namespace helpers {
|
||||||
numResult->p(0,numResults);
|
numResult->p(0,numResults);
|
||||||
|
|
||||||
return output;
|
return output;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nd4j::NDArray* processCondition(nd4j::LaunchContext * context, int mode,nd4j::NDArray *arg, nd4j::NDArray *comp, nd4j::NDArray *output, nd4j::NDArray *numResult, nd4j::NDArray& compScalar) {
|
nd4j::NDArray* processCondition(nd4j::LaunchContext * context, int mode,nd4j::NDArray *arg, nd4j::NDArray *comp, nd4j::NDArray *output, nd4j::NDArray *numResult, nd4j::NDArray& compScalar) {
|
||||||
|
|
|
@ -127,5 +127,23 @@ TEST_F(BooleanOpsTests, Is_numeric_tensor_1) {
|
||||||
nd4j::ops::is_numeric_tensor op;
|
nd4j::ops::is_numeric_tensor op;
|
||||||
|
|
||||||
ASSERT_TRUE(op.evaluate({&x}));
|
ASSERT_TRUE(op.evaluate({&x}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BooleanOpsTests, test_where_1) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {6}, { 1, -3, 4, 8, -2, 5 });
|
||||||
|
auto y = NDArrayFactory::create<double>('c', {6}, { 2, -3, 1, 1, -2, 1 });
|
||||||
|
auto e = NDArrayFactory::create<double>('c', {3}, { 4, 8, 5 });
|
||||||
|
|
||||||
|
nd4j:ops::choose op;
|
||||||
|
|
||||||
|
auto result = op.execute({&x, &y}, {}, {3});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
z->printIndexedBuffer("z");
|
||||||
|
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
|
@ -113,7 +113,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR) {
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
ASSERT_EQ(4,z->lengthOf());
|
ASSERT_EQ(3, z->lengthOf());
|
||||||
//ASSERT_TRUE(exp.isSameShape(z));
|
//ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
delete result;
|
delete result;
|
||||||
|
@ -137,7 +137,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_LEFT) {
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
ASSERT_EQ(4,z->lengthOf());
|
ASSERT_EQ(3,z->lengthOf());
|
||||||
//ASSERT_TRUE(exp.isSameShape(z));
|
//ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
delete result;
|
delete result;
|
||||||
|
@ -160,7 +160,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_ONLY_SCALAR) {
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
ASSERT_EQ(4,z->lengthOf());
|
ASSERT_EQ(2,z->lengthOf());
|
||||||
//ASSERT_TRUE(exp.isSameShape(z));
|
//ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
delete result;
|
delete result;
|
||||||
|
@ -183,7 +183,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_ONLY_SCALAR_GTE) {
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
ASSERT_EQ(4,z->lengthOf());
|
ASSERT_EQ(3,z->lengthOf());
|
||||||
//ASSERT_TRUE(exp.isSameShape(z));
|
//ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
delete result;
|
delete result;
|
||||||
|
|
|
@ -209,7 +209,7 @@ file(GLOB_RECURSE LOOPS_SOURCES false ../../include/loops/*.cpp ../../include/lo
|
||||||
message("CPU BLAS")
|
message("CPU BLAS")
|
||||||
add_definitions(-D__CPUBLAS__=true)
|
add_definitions(-D__CPUBLAS__=true)
|
||||||
|
|
||||||
if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(MINGW))
|
if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(MINGW) AND NOT(APPLE))
|
||||||
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -Wl,-export-dynamic")
|
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -Wl,-export-dynamic")
|
||||||
SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic")
|
SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic")
|
||||||
endif()
|
endif()
|
||||||
|
|
|
@ -69,7 +69,6 @@ public class Choose extends DynamicCustomOp {
|
||||||
|
|
||||||
addInputArgument(inputs);
|
addInputArgument(inputs);
|
||||||
addIArgument(condition.condtionNum());
|
addIArgument(condition.condtionNum());
|
||||||
addOutputArgument(Nd4j.create(inputs[0].length()),Nd4j.scalar(1.0));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -106,8 +105,6 @@ public class Choose extends DynamicCustomOp {
|
||||||
if(!tArgs.isEmpty())
|
if(!tArgs.isEmpty())
|
||||||
addTArgument(Doubles.toArray(tArgs));
|
addTArgument(Doubles.toArray(tArgs));
|
||||||
addIArgument(condition.condtionNum());
|
addIArgument(condition.condtionNum());
|
||||||
|
|
||||||
addOutputArgument(Nd4j.create(inputs[0].shape(), inputs[0].ordering()),Nd4j.scalar(DataType.LONG, 1.0));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public Choose(String opName, SameDiff sameDiff, SDVariable[] args, boolean inPlace) {
|
public Choose(String opName, SameDiff sameDiff, SDVariable[] args, boolean inPlace) {
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.linalg.indexing;
|
package org.nd4j.linalg.indexing;
|
||||||
|
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
|
import lombok.val;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex;
|
||||||
|
@ -191,9 +192,9 @@ public class BooleanIndexing {
|
||||||
* ffor the given conditions
|
* ffor the given conditions
|
||||||
*/
|
*/
|
||||||
public static INDArray chooseFrom(@NonNull INDArray[] input,@NonNull Condition condition) {
|
public static INDArray chooseFrom(@NonNull INDArray[] input,@NonNull Condition condition) {
|
||||||
Choose choose = new Choose(input,condition);
|
val choose = new Choose(input,condition);
|
||||||
Nd4j.getExecutioner().execAndReturn(choose);
|
val outputs = Nd4j.exec(choose);
|
||||||
int secondOutput = choose.getOutputArgument(1).getInt(0);
|
int secondOutput = outputs[1].getInt(0);
|
||||||
if(secondOutput < 1) {
|
if(secondOutput < 1) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
|
@ -2810,6 +2810,44 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps {
|
||||||
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
|
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
|
||||||
@Cast("bool") boolean descending);
|
@Cast("bool") boolean descending);
|
||||||
|
|
||||||
|
public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||||
|
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
|
||||||
|
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
|
||||||
|
Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo,
|
||||||
|
Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo,
|
||||||
|
@Cast("bool") boolean descending);
|
||||||
|
public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||||
|
Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo,
|
||||||
|
Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo,
|
||||||
|
Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo,
|
||||||
|
Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo,
|
||||||
|
@Cast("bool") boolean descending);
|
||||||
|
public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||||
|
Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo,
|
||||||
|
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
|
||||||
|
Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo,
|
||||||
|
Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo,
|
||||||
|
@Cast("bool") boolean descending);
|
||||||
|
|
||||||
|
public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||||
|
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
|
||||||
|
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
|
||||||
|
Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo,
|
||||||
|
Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo,
|
||||||
|
@Cast("bool") boolean descending);
|
||||||
|
public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||||
|
Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo,
|
||||||
|
Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo,
|
||||||
|
Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo,
|
||||||
|
Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo,
|
||||||
|
@Cast("bool") boolean descending);
|
||||||
|
public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||||
|
Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo,
|
||||||
|
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
|
||||||
|
Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo,
|
||||||
|
Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo,
|
||||||
|
@Cast("bool") boolean descending);
|
||||||
|
|
||||||
public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||||
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
|
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
|
||||||
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
|
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
|
||||||
|
@ -2835,6 +2873,56 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps {
|
||||||
@Cast("Nd4jLong*") long[] tadOffsets,
|
@Cast("Nd4jLong*") long[] tadOffsets,
|
||||||
@Cast("bool") boolean descending);
|
@Cast("bool") boolean descending);
|
||||||
|
|
||||||
|
public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||||
|
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
|
||||||
|
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
|
||||||
|
Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo,
|
||||||
|
Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo,
|
||||||
|
IntPointer dimension,
|
||||||
|
int dimensionLength,
|
||||||
|
@Cast("bool") boolean descending);
|
||||||
|
public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||||
|
Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo,
|
||||||
|
Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo,
|
||||||
|
Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo,
|
||||||
|
Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo,
|
||||||
|
IntBuffer dimension,
|
||||||
|
int dimensionLength,
|
||||||
|
@Cast("bool") boolean descending);
|
||||||
|
public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||||
|
Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo,
|
||||||
|
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
|
||||||
|
Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo,
|
||||||
|
Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo,
|
||||||
|
int[] dimension,
|
||||||
|
int dimensionLength,
|
||||||
|
@Cast("bool") boolean descending);
|
||||||
|
|
||||||
|
public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||||
|
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
|
||||||
|
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
|
||||||
|
Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo,
|
||||||
|
Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo,
|
||||||
|
IntPointer dimension,
|
||||||
|
int dimensionLength,
|
||||||
|
@Cast("bool") boolean descending);
|
||||||
|
public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||||
|
Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo,
|
||||||
|
Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo,
|
||||||
|
Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo,
|
||||||
|
Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo,
|
||||||
|
IntBuffer dimension,
|
||||||
|
int dimensionLength,
|
||||||
|
@Cast("bool") boolean descending);
|
||||||
|
public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||||
|
Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo,
|
||||||
|
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
|
||||||
|
Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo,
|
||||||
|
Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo,
|
||||||
|
int[] dimension,
|
||||||
|
int dimensionLength,
|
||||||
|
@Cast("bool") boolean descending);
|
||||||
|
|
||||||
|
|
||||||
// special sort impl for sorting out COO indices and values
|
// special sort impl for sorting out COO indices and values
|
||||||
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, Pointer values, @Cast("Nd4jLong") long length, int rank);
|
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, Pointer values, @Cast("Nd4jLong") long length, int rank);
|
||||||
|
@ -9617,8 +9705,8 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
// #define BROADCAST_BOOL(NAME) nd4j::BroadcastBoolOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME)
|
// #define BROADCAST_BOOL(NAME) nd4j::BroadcastBoolOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME)
|
||||||
|
|
||||||
|
|
||||||
public static final int ALL_INTS =INT64;
|
public static final int ALL_INTS =UINT64;
|
||||||
public static final int ALL_FLOATS =DOUBLE;
|
public static final int ALL_FLOATS =BFLOAT16;
|
||||||
|
|
||||||
// #endif //TESTS_CPU_TYPE_BOILERPLATE_H
|
// #endif //TESTS_CPU_TYPE_BOILERPLATE_H
|
||||||
|
|
||||||
|
|
|
@ -7641,12 +7641,13 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
assertEquals(scalarRank2, scalarRank2.dup());
|
assertEquals(scalarRank2, scalarRank2.dup());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Ignore // https://github.com/deeplearning4j/deeplearning4j/issues/7632
|
//@Ignore // https://github.com/deeplearning4j/deeplearning4j/issues/7632
|
||||||
@Test
|
@Test
|
||||||
public void testGetWhereINDArray() {
|
public void testGetWhereINDArray() {
|
||||||
INDArray input = Nd4j.create(new double[] { 1, -3, 4, 8, -2, 5 });
|
INDArray input = Nd4j.create(new double[] { 1, -3, 4, 8, -2, 5 });
|
||||||
|
INDArray comp = Nd4j.create(new double[]{2, -3, 1, 1, -2, 1 });
|
||||||
INDArray expected = Nd4j.create(new double[] { 4, 8, 5 });
|
INDArray expected = Nd4j.create(new double[] { 4, 8, 5 });
|
||||||
INDArray actual = input.getWhere(input, Conditions.greaterThan(1));
|
INDArray actual = input.getWhere(comp, Conditions.greaterThan(1));
|
||||||
|
|
||||||
assertEquals(expected, actual);
|
assertEquals(expected, actual);
|
||||||
}
|
}
|
||||||
|
@ -7655,7 +7656,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
public void testGetWhereNumber() {
|
public void testGetWhereNumber() {
|
||||||
INDArray input = Nd4j.create(new double[] { 1, -3, 4, 8, -2, 5 });
|
INDArray input = Nd4j.create(new double[] { 1, -3, 4, 8, -2, 5 });
|
||||||
INDArray expected = Nd4j.create(new double[] { 8, 5 });
|
INDArray expected = Nd4j.create(new double[] { 8, 5 });
|
||||||
INDArray actual = input.getWhere(4, Conditions.greaterThan(6));
|
INDArray actual = input.getWhere(4, Conditions.greaterThan(1));
|
||||||
|
|
||||||
assertEquals(expected, actual);
|
assertEquals(expected, actual);
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,6 @@ import org.nd4j.linalg.api.memory.enums.LearningPolicy;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import sun.awt.image.DataBufferNative;
|
|
||||||
|
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.nio.ByteOrder;
|
import java.nio.ByteOrder;
|
||||||
|
|
Loading…
Reference in New Issue