Small fixes (#43)

* minor cmake changes to make macos happy

* space_to_batch/batch_to_space validation fix

* - choose op tweaks
- tests updated to match appleid tweaks

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* - get rid of bad import

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* - choose now uses shape function
- choose test updated
master
raver119 2019-07-03 14:24:50 +03:00 committed by AlexDBlack
parent 8b4b8977a0
commit dde50ee570
14 changed files with 145 additions and 33 deletions

View File

@ -349,7 +349,7 @@ elseif(CPU_BLAS)
endif() endif()
endif() endif()
if (CMAKE_BUILD_TYPE STREQUAL "Debug") if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(APPLE))
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -Wl,-export-dynamic") SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -Wl,-export-dynamic")
SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic") SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic")
endif() endif()

View File

@ -27,7 +27,7 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
CUSTOM_OP_IMPL(choose, -1, 2, false, -1, -1) { CUSTOM_OP_IMPL(choose, -1, 2, false, -2, -1) {
int mode = INT_ARG(0); int mode = INT_ARG(0);
auto result = OUTPUT_VARIABLE(0); auto result = OUTPUT_VARIABLE(0);
@ -61,6 +61,8 @@ namespace nd4j {
DECLARE_SHAPE_FN(choose) { DECLARE_SHAPE_FN(choose) {
Nd4jLong *shape; Nd4jLong *shape;
int rank; int rank;
int mode = INT_ARG(0);
auto numResults = NDArrayFactory::create<Nd4jLong>(0L);
if(block.width() > 1) { if(block.width() > 1) {
auto first = INPUT_VARIABLE(0); auto first = INPUT_VARIABLE(0);
auto second = INPUT_VARIABLE(1); auto second = INPUT_VARIABLE(1);
@ -72,18 +74,22 @@ namespace nd4j {
shape = second->getShapeInfo(); shape = second->getShapeInfo();
rank = second->rankOf(); rank = second->rankOf();
} }
helpers::chooseFunctorArray(block.launchContext(), first, second, mode, nullptr, &numResults);
} }
else { else {
auto first = INPUT_VARIABLE(0); auto first = INPUT_VARIABLE(0);
shape = first->getShapeInfo(); shape = first->getShapeInfo();
rank = first->rankOf(); rank = first->rankOf();
double scalar = T_ARG(0);
helpers::chooseFunctorScalar(block.launchContext(), first, scalar, mode, nullptr, &numResults);
} }
Nd4jLong* newShape; auto newShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(numResults.e<Nd4jLong>(0), ArrayOptions::dataType(inputShape->at(0)));
COPY_SHAPE(shape, newShape);
auto shapeScalar = ConstantShapeHelper::getInstance()->scalarShapeInfo(nd4j::DataType::INT64); auto shapeScalar = ConstantShapeHelper::getInstance()->scalarShapeInfo(nd4j::DataType::INT64);
return SHAPELIST(CONSTANT(newShape), shapeScalar); return SHAPELIST(newShape, shapeScalar);
} }

View File

@ -68,10 +68,10 @@ namespace ops {
auto blocks = INPUT_VARIABLE(1); auto blocks = INPUT_VARIABLE(1);
auto crops = INPUT_VARIABLE(2); auto crops = INPUT_VARIABLE(2);
block_dims = (int) blocks->lengthOf(); block_dims = (int) blocks->sizeAt(0);
REQUIRE_TRUE(blocks->isVector() || blocks->lengthOf() == 1, 0, "BatchToSpace: blocks supposed to be vector or scalar, but got %iD instead", blocks->rankOf()); REQUIRE_TRUE(blocks->isVector() || blocks->lengthOf() == 1, 0, "BatchToSpace: blocks supposed to be vector or scalar, but got %iD instead", blocks->rankOf());
REQUIRE_TRUE(input->rankOf() >= 1 + blocks->lengthOf() + 1, 0, "BatchToSpace: blocks length + 2 should match input rank at least"); REQUIRE_TRUE(input->rankOf() >= 1 + blocks->lengthOf(), 0, "BatchToSpace: blocks length + 1 should match input rank at least");
REQUIRE_TRUE(crops->rankOf() == 2, 0, "BatchToSpace: padding should have rank of 2, but got %i instead", crops->rankOf()); REQUIRE_TRUE(crops->rankOf() == 2, 0, "BatchToSpace: padding should have rank of 2, but got %i instead", crops->rankOf());
REQUIRE_TRUE(crops->columns() == 2 && blocks->lengthOf() == crops->rows(), 0, "BatchToSpace: padding should have M rows and 2 columns"); REQUIRE_TRUE(crops->columns() == 2 && blocks->lengthOf() == crops->rows(), 0, "BatchToSpace: padding should have M rows and 2 columns");
@ -198,7 +198,7 @@ namespace ops {
auto blocks = INPUT_VARIABLE(1); auto blocks = INPUT_VARIABLE(1);
auto crops = INPUT_VARIABLE(2); auto crops = INPUT_VARIABLE(2);
block_dims = (int) blocks->lengthOf(); block_dims = (int) blocks->sizeAt(0);
block_shape = blocks->template asVectorT<int>(); block_shape = blocks->template asVectorT<int>();
crops_shape = crops->template asVectorT<int>(); crops_shape = crops->template asVectorT<int>();

View File

@ -75,7 +75,7 @@ namespace ops {
block_shape.resize(block_dims); block_shape.resize(block_dims);
padding_shape.resize(M*2); padding_shape.resize(M*2);
REQUIRE_TRUE(input->rankOf() >= 1 + M + 1, 0, "SpaceToBatch: blocks length + 2 should match input rank at least"); REQUIRE_TRUE(input->rankOf() >= 1 + M, 0, "SpaceToBatch: blocks length + 1 should match input rank at least");
int e = 0; int e = 0;
for (; e < block_dims; e++) for (; e < block_dims; e++)

View File

@ -120,7 +120,7 @@ namespace nd4j {
* @tparam T * @tparam T
*/ */
#if NOT_EXCLUDED(OP_choose) #if NOT_EXCLUDED(OP_choose)
DECLARE_CUSTOM_OP(choose, -1, 1, false, -1, -1); DECLARE_CUSTOM_OP(choose, -1, 1, false, -2, -1);
#endif #endif
/** /**

View File

@ -46,7 +46,8 @@ namespace helpers {
for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { for (Nd4jLong i = 0; i < arg->lengthOf(); i++) {
T result2 = processElementCondition(mode, arg->e<T>(i), comp->e<T>(0)); T result2 = processElementCondition(mode, arg->e<T>(i), comp->e<T>(0));
if(result2 > 0) { if(result2 > 0) {
output->p(numResults, arg->e<T>(i)); if (output != nullptr)
output->p(numResults, arg->e<T>(i));
numResults++; numResults++;
} }
} }
@ -56,9 +57,10 @@ namespace helpers {
//for comparison //for comparison
nd4j::NDArray arg1 = *arg; nd4j::NDArray arg1 = *arg;
for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { for (Nd4jLong i = 0; i < arg->lengthOf(); i++) {
T result2 = processElementCondition(mode, arg->e<T>(i), compScalar.e<T>(0)); T result2 = processElementCondition(mode, arg->e<T>(i), comp->e<T>(i));
if(result2 > 0) { if(result2 > 0) {
output->p(numResults, arg->e<T>(i)); if (output != nullptr)
output->p(numResults, arg->e<T>(i));
numResults++; numResults++;
} }
} }
@ -72,7 +74,8 @@ namespace helpers {
for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { for (Nd4jLong i = 0; i < arg->lengthOf(); i++) {
T result2 = processElementCondition(mode, arg->e<T>(i), compScalar.e<T>(0)); T result2 = processElementCondition(mode, arg->e<T>(i), compScalar.e<T>(0));
if(result2 > 0) { if(result2 > 0) {
output->p(numResults, arg->e<T>(i)); if (output != nullptr)
output->p(numResults, arg->e<T>(i));
numResults++; numResults++;
} }
} }
@ -82,7 +85,6 @@ namespace helpers {
numResult->p(0,numResults); numResult->p(0,numResults);
return output; return output;
} }
nd4j::NDArray* processCondition(nd4j::LaunchContext * context, int mode,nd4j::NDArray *arg, nd4j::NDArray *comp, nd4j::NDArray *output, nd4j::NDArray *numResult, nd4j::NDArray& compScalar) { nd4j::NDArray* processCondition(nd4j::LaunchContext * context, int mode,nd4j::NDArray *arg, nd4j::NDArray *comp, nd4j::NDArray *output, nd4j::NDArray *numResult, nd4j::NDArray& compScalar) {

View File

@ -127,5 +127,23 @@ TEST_F(BooleanOpsTests, Is_numeric_tensor_1) {
nd4j::ops::is_numeric_tensor op; nd4j::ops::is_numeric_tensor op;
ASSERT_TRUE(op.evaluate({&x})); ASSERT_TRUE(op.evaluate({&x}));
}
TEST_F(BooleanOpsTests, test_where_1) {
auto x = NDArrayFactory::create<double>('c', {6}, { 1, -3, 4, 8, -2, 5 });
auto y = NDArrayFactory::create<double>('c', {6}, { 2, -3, 1, 1, -2, 1 });
auto e = NDArrayFactory::create<double>('c', {3}, { 4, 8, 5 });
nd4j:ops::choose op;
auto result = op.execute({&x, &y}, {}, {3});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
z->printIndexedBuffer("z");
ASSERT_EQ(e, *z);
delete result;
} }

View File

@ -113,7 +113,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR) {
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
ASSERT_EQ(4,z->lengthOf()); ASSERT_EQ(3, z->lengthOf());
//ASSERT_TRUE(exp.isSameShape(z)); //ASSERT_TRUE(exp.isSameShape(z));
delete result; delete result;
@ -137,7 +137,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_LEFT) {
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
ASSERT_EQ(4,z->lengthOf()); ASSERT_EQ(3,z->lengthOf());
//ASSERT_TRUE(exp.isSameShape(z)); //ASSERT_TRUE(exp.isSameShape(z));
delete result; delete result;
@ -160,7 +160,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_ONLY_SCALAR) {
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
ASSERT_EQ(4,z->lengthOf()); ASSERT_EQ(2,z->lengthOf());
//ASSERT_TRUE(exp.isSameShape(z)); //ASSERT_TRUE(exp.isSameShape(z));
delete result; delete result;
@ -183,7 +183,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_ONLY_SCALAR_GTE) {
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
ASSERT_EQ(4,z->lengthOf()); ASSERT_EQ(3,z->lengthOf());
//ASSERT_TRUE(exp.isSameShape(z)); //ASSERT_TRUE(exp.isSameShape(z));
delete result; delete result;

View File

@ -209,7 +209,7 @@ file(GLOB_RECURSE LOOPS_SOURCES false ../../include/loops/*.cpp ../../include/lo
message("CPU BLAS") message("CPU BLAS")
add_definitions(-D__CPUBLAS__=true) add_definitions(-D__CPUBLAS__=true)
if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(MINGW)) if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(MINGW) AND NOT(APPLE))
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -Wl,-export-dynamic") SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -Wl,-export-dynamic")
SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic") SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic")
endif() endif()

View File

@ -69,7 +69,6 @@ public class Choose extends DynamicCustomOp {
addInputArgument(inputs); addInputArgument(inputs);
addIArgument(condition.condtionNum()); addIArgument(condition.condtionNum());
addOutputArgument(Nd4j.create(inputs[0].length()),Nd4j.scalar(1.0));
} }
/** /**
@ -106,8 +105,6 @@ public class Choose extends DynamicCustomOp {
if(!tArgs.isEmpty()) if(!tArgs.isEmpty())
addTArgument(Doubles.toArray(tArgs)); addTArgument(Doubles.toArray(tArgs));
addIArgument(condition.condtionNum()); addIArgument(condition.condtionNum());
addOutputArgument(Nd4j.create(inputs[0].shape(), inputs[0].ordering()),Nd4j.scalar(DataType.LONG, 1.0));
} }
public Choose(String opName, SameDiff sameDiff, SDVariable[] args, boolean inPlace) { public Choose(String opName, SameDiff sameDiff, SDVariable[] args, boolean inPlace) {

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.indexing; package org.nd4j.linalg.indexing;
import lombok.NonNull; import lombok.NonNull;
import lombok.val;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex; import org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex;
@ -191,9 +192,9 @@ public class BooleanIndexing {
* ffor the given conditions * ffor the given conditions
*/ */
public static INDArray chooseFrom(@NonNull INDArray[] input,@NonNull Condition condition) { public static INDArray chooseFrom(@NonNull INDArray[] input,@NonNull Condition condition) {
Choose choose = new Choose(input,condition); val choose = new Choose(input,condition);
Nd4j.getExecutioner().execAndReturn(choose); val outputs = Nd4j.exec(choose);
int secondOutput = choose.getOutputArgument(1).getInt(0); int secondOutput = outputs[1].getInt(0);
if(secondOutput < 1) { if(secondOutput < 1) {
return null; return null;
} }

View File

@ -2810,6 +2810,44 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps {
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo, Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
@Cast("bool") boolean descending); @Cast("bool") boolean descending);
public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo,
@Cast("bool") boolean descending);
public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo,
@Cast("bool") boolean descending);
public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo,
@Cast("bool") boolean descending);
public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo,
@Cast("bool") boolean descending);
public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo,
@Cast("bool") boolean descending);
public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo,
@Cast("bool") boolean descending);
public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo, Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo, Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
@ -2835,6 +2873,56 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps {
@Cast("Nd4jLong*") long[] tadOffsets, @Cast("Nd4jLong*") long[] tadOffsets,
@Cast("bool") boolean descending); @Cast("bool") boolean descending);
public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo,
IntPointer dimension,
int dimensionLength,
@Cast("bool") boolean descending);
public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo,
IntBuffer dimension,
int dimensionLength,
@Cast("bool") boolean descending);
public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo,
int[] dimension,
int dimensionLength,
@Cast("bool") boolean descending);
public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo,
IntPointer dimension,
int dimensionLength,
@Cast("bool") boolean descending);
public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo,
IntBuffer dimension,
int dimensionLength,
@Cast("bool") boolean descending);
public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo,
int[] dimension,
int dimensionLength,
@Cast("bool") boolean descending);
// special sort impl for sorting out COO indices and values // special sort impl for sorting out COO indices and values
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, Pointer values, @Cast("Nd4jLong") long length, int rank); public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, Pointer values, @Cast("Nd4jLong") long length, int rank);
@ -9617,8 +9705,8 @@ public static final int PREALLOC_SIZE = 33554432;
// #define BROADCAST_BOOL(NAME) nd4j::BroadcastBoolOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME) // #define BROADCAST_BOOL(NAME) nd4j::BroadcastBoolOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME)
public static final int ALL_INTS =INT64; public static final int ALL_INTS =UINT64;
public static final int ALL_FLOATS =DOUBLE; public static final int ALL_FLOATS =BFLOAT16;
// #endif //TESTS_CPU_TYPE_BOILERPLATE_H // #endif //TESTS_CPU_TYPE_BOILERPLATE_H

View File

@ -7641,12 +7641,13 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(scalarRank2, scalarRank2.dup()); assertEquals(scalarRank2, scalarRank2.dup());
} }
@Ignore // https://github.com/deeplearning4j/deeplearning4j/issues/7632 //@Ignore // https://github.com/deeplearning4j/deeplearning4j/issues/7632
@Test @Test
public void testGetWhereINDArray() { public void testGetWhereINDArray() {
INDArray input = Nd4j.create(new double[] { 1, -3, 4, 8, -2, 5 }); INDArray input = Nd4j.create(new double[] { 1, -3, 4, 8, -2, 5 });
INDArray comp = Nd4j.create(new double[]{2, -3, 1, 1, -2, 1 });
INDArray expected = Nd4j.create(new double[] { 4, 8, 5 }); INDArray expected = Nd4j.create(new double[] { 4, 8, 5 });
INDArray actual = input.getWhere(input, Conditions.greaterThan(1)); INDArray actual = input.getWhere(comp, Conditions.greaterThan(1));
assertEquals(expected, actual); assertEquals(expected, actual);
} }
@ -7655,7 +7656,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
public void testGetWhereNumber() { public void testGetWhereNumber() {
INDArray input = Nd4j.create(new double[] { 1, -3, 4, 8, -2, 5 }); INDArray input = Nd4j.create(new double[] { 1, -3, 4, 8, -2, 5 });
INDArray expected = Nd4j.create(new double[] { 8, 5 }); INDArray expected = Nd4j.create(new double[] { 8, 5 });
INDArray actual = input.getWhere(4, Conditions.greaterThan(6)); INDArray actual = input.getWhere(4, Conditions.greaterThan(1));
assertEquals(expected, actual); assertEquals(expected, actual);
} }

View File

@ -31,7 +31,6 @@ import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import sun.awt.image.DataBufferNative;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.ByteOrder; import java.nio.ByteOrder;