diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt index 71470ed09..ad8a781a6 100755 --- a/libnd4j/blas/CMakeLists.txt +++ b/libnd4j/blas/CMakeLists.txt @@ -349,7 +349,7 @@ elseif(CPU_BLAS) endif() endif() - if (CMAKE_BUILD_TYPE STREQUAL "Debug") + if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(APPLE)) SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -Wl,-export-dynamic") SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic") endif() diff --git a/libnd4j/include/ops/declarable/generic/boolean/choose.cpp b/libnd4j/include/ops/declarable/generic/boolean/choose.cpp index 23652fde6..c0e6ee49f 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/choose.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/choose.cpp @@ -27,7 +27,7 @@ namespace nd4j { namespace ops { - CUSTOM_OP_IMPL(choose, -1, 2, false, -1, -1) { + CUSTOM_OP_IMPL(choose, -1, 2, false, -2, -1) { int mode = INT_ARG(0); auto result = OUTPUT_VARIABLE(0); @@ -61,6 +61,8 @@ namespace nd4j { DECLARE_SHAPE_FN(choose) { Nd4jLong *shape; int rank; + int mode = INT_ARG(0); + auto numResults = NDArrayFactory::create(0L); if(block.width() > 1) { auto first = INPUT_VARIABLE(0); auto second = INPUT_VARIABLE(1); @@ -72,18 +74,22 @@ namespace nd4j { shape = second->getShapeInfo(); rank = second->rankOf(); } + + helpers::chooseFunctorArray(block.launchContext(), first, second, mode, nullptr, &numResults); } else { auto first = INPUT_VARIABLE(0); shape = first->getShapeInfo(); rank = first->rankOf(); + double scalar = T_ARG(0); + + helpers::chooseFunctorScalar(block.launchContext(), first, scalar, mode, nullptr, &numResults); } - Nd4jLong* newShape; - COPY_SHAPE(shape, newShape); + auto newShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(numResults.e(0), ArrayOptions::dataType(inputShape->at(0))); auto shapeScalar = ConstantShapeHelper::getInstance()->scalarShapeInfo(nd4j::DataType::INT64); - return SHAPELIST(CONSTANT(newShape), shapeScalar); + return SHAPELIST(newShape, shapeScalar); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/batch_to_space.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/batch_to_space.cpp index 90d14243d..72b6b997b 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/batch_to_space.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/batch_to_space.cpp @@ -68,10 +68,10 @@ namespace ops { auto blocks = INPUT_VARIABLE(1); auto crops = INPUT_VARIABLE(2); - block_dims = (int) blocks->lengthOf(); + block_dims = (int) blocks->sizeAt(0); REQUIRE_TRUE(blocks->isVector() || blocks->lengthOf() == 1, 0, "BatchToSpace: blocks supposed to be vector or scalar, but got %iD instead", blocks->rankOf()); - REQUIRE_TRUE(input->rankOf() >= 1 + blocks->lengthOf() + 1, 0, "BatchToSpace: blocks length + 2 should match input rank at least"); + REQUIRE_TRUE(input->rankOf() >= 1 + blocks->lengthOf(), 0, "BatchToSpace: blocks length + 1 should match input rank at least"); REQUIRE_TRUE(crops->rankOf() == 2, 0, "BatchToSpace: padding should have rank of 2, but got %i instead", crops->rankOf()); REQUIRE_TRUE(crops->columns() == 2 && blocks->lengthOf() == crops->rows(), 0, "BatchToSpace: padding should have M rows and 2 columns"); @@ -198,7 +198,7 @@ namespace ops { auto blocks = INPUT_VARIABLE(1); auto crops = INPUT_VARIABLE(2); - block_dims = (int) blocks->lengthOf(); + block_dims = (int) blocks->sizeAt(0); block_shape = blocks->template asVectorT(); crops_shape = crops->template asVectorT(); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/space_to_batch.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/space_to_batch.cpp index f88aa0091..fecf914f2 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/space_to_batch.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/space_to_batch.cpp @@ -75,7 +75,7 @@ namespace ops { block_shape.resize(block_dims); padding_shape.resize(M*2); - REQUIRE_TRUE(input->rankOf() >= 1 + M + 1, 0, "SpaceToBatch: blocks length + 2 should match input rank at least"); + REQUIRE_TRUE(input->rankOf() >= 1 + M, 0, "SpaceToBatch: blocks length + 1 should match input rank at least"); int e = 0; for (; e < block_dims; e++) diff --git a/libnd4j/include/ops/declarable/headers/boolean.h b/libnd4j/include/ops/declarable/headers/boolean.h index 44654485e..21fe48202 100644 --- a/libnd4j/include/ops/declarable/headers/boolean.h +++ b/libnd4j/include/ops/declarable/headers/boolean.h @@ -120,7 +120,7 @@ namespace nd4j { * @tparam T */ #if NOT_EXCLUDED(OP_choose) - DECLARE_CUSTOM_OP(choose, -1, 1, false, -1, -1); + DECLARE_CUSTOM_OP(choose, -1, 1, false, -2, -1); #endif /** diff --git a/libnd4j/include/ops/declarable/helpers/impl/choose.cpp b/libnd4j/include/ops/declarable/helpers/impl/choose.cpp index 100fec893..d18cde269 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/choose.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/choose.cpp @@ -46,7 +46,8 @@ namespace helpers { for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { T result2 = processElementCondition(mode, arg->e(i), comp->e(0)); if(result2 > 0) { - output->p(numResults, arg->e(i)); + if (output != nullptr) + output->p(numResults, arg->e(i)); numResults++; } } @@ -56,9 +57,10 @@ namespace helpers { //for comparison nd4j::NDArray arg1 = *arg; for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { - T result2 = processElementCondition(mode, arg->e(i), compScalar.e(0)); + T result2 = processElementCondition(mode, arg->e(i), comp->e(i)); if(result2 > 0) { - output->p(numResults, arg->e(i)); + if (output != nullptr) + output->p(numResults, arg->e(i)); numResults++; } } @@ -72,7 +74,8 @@ namespace helpers { for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { T result2 = processElementCondition(mode, arg->e(i), compScalar.e(0)); if(result2 > 0) { - output->p(numResults, arg->e(i)); + if (output != nullptr) + output->p(numResults, arg->e(i)); numResults++; } } @@ -82,7 +85,6 @@ namespace helpers { numResult->p(0,numResults); return output; - } nd4j::NDArray* processCondition(nd4j::LaunchContext * context, int mode,nd4j::NDArray *arg, nd4j::NDArray *comp, nd4j::NDArray *output, nd4j::NDArray *numResult, nd4j::NDArray& compScalar) { diff --git a/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp index be92a2ada..2cbc8513e 100644 --- a/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp @@ -127,5 +127,23 @@ TEST_F(BooleanOpsTests, Is_numeric_tensor_1) { nd4j::ops::is_numeric_tensor op; ASSERT_TRUE(op.evaluate({&x})); - +} + +TEST_F(BooleanOpsTests, test_where_1) { + auto x = NDArrayFactory::create('c', {6}, { 1, -3, 4, 8, -2, 5 }); + auto y = NDArrayFactory::create('c', {6}, { 2, -3, 1, 1, -2, 1 }); + auto e = NDArrayFactory::create('c', {3}, { 4, 8, 5 }); + + nd4j:ops::choose op; + + auto result = op.execute({&x, &y}, {}, {3}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + z->printIndexedBuffer("z"); + + ASSERT_EQ(e, *z); + + delete result; } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp index a37228aba..c0eba69f1 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -113,7 +113,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR) { ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); - ASSERT_EQ(4,z->lengthOf()); + ASSERT_EQ(3, z->lengthOf()); //ASSERT_TRUE(exp.isSameShape(z)); delete result; @@ -137,7 +137,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_LEFT) { ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); - ASSERT_EQ(4,z->lengthOf()); + ASSERT_EQ(3,z->lengthOf()); //ASSERT_TRUE(exp.isSameShape(z)); delete result; @@ -160,7 +160,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_ONLY_SCALAR) { ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); - ASSERT_EQ(4,z->lengthOf()); + ASSERT_EQ(2,z->lengthOf()); //ASSERT_TRUE(exp.isSameShape(z)); delete result; @@ -183,7 +183,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_ONLY_SCALAR_GTE) { ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); - ASSERT_EQ(4,z->lengthOf()); + ASSERT_EQ(3,z->lengthOf()); //ASSERT_TRUE(exp.isSameShape(z)); delete result; diff --git a/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt b/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt index 285ba6d42..38e3b9523 100644 --- a/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt +++ b/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt @@ -209,7 +209,7 @@ file(GLOB_RECURSE LOOPS_SOURCES false ../../include/loops/*.cpp ../../include/lo message("CPU BLAS") add_definitions(-D__CPUBLAS__=true) -if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(MINGW)) +if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(MINGW) AND NOT(APPLE)) SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -Wl,-export-dynamic") SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic") endif() diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Choose.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Choose.java index cfbd4f29f..28151a899 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Choose.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Choose.java @@ -69,7 +69,6 @@ public class Choose extends DynamicCustomOp { addInputArgument(inputs); addIArgument(condition.condtionNum()); - addOutputArgument(Nd4j.create(inputs[0].length()),Nd4j.scalar(1.0)); } /** @@ -106,8 +105,6 @@ public class Choose extends DynamicCustomOp { if(!tArgs.isEmpty()) addTArgument(Doubles.toArray(tArgs)); addIArgument(condition.condtionNum()); - - addOutputArgument(Nd4j.create(inputs[0].shape(), inputs[0].ordering()),Nd4j.scalar(DataType.LONG, 1.0)); } public Choose(String opName, SameDiff sameDiff, SDVariable[] args, boolean inPlace) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java index e6a51f0cb..6bc48778f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.indexing; import lombok.NonNull; +import lombok.val; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex; @@ -191,9 +192,9 @@ public class BooleanIndexing { * ffor the given conditions */ public static INDArray chooseFrom(@NonNull INDArray[] input,@NonNull Condition condition) { - Choose choose = new Choose(input,condition); - Nd4j.getExecutioner().execAndReturn(choose); - int secondOutput = choose.getOutputArgument(1).getInt(0); + val choose = new Choose(input,condition); + val outputs = Nd4j.exec(choose); + int secondOutput = outputs[1].getInt(0); if(secondOutput < 1) { return null; } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index ad68aad50..a41a1e3c7 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -2810,6 +2810,44 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps { Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo, @Cast("bool") boolean descending); + public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo, + @Cast("bool") boolean descending); + public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo, + @Cast("bool") boolean descending); + public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo, + @Cast("bool") boolean descending); + + public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo, + @Cast("bool") boolean descending); + public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo, + @Cast("bool") boolean descending); + public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo, + @Cast("bool") boolean descending); + public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo, Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo, @@ -2835,6 +2873,56 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps { @Cast("Nd4jLong*") long[] tadOffsets, @Cast("bool") boolean descending); + public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo, + IntPointer dimension, + int dimensionLength, + @Cast("bool") boolean descending); + public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo, + IntBuffer dimension, + int dimensionLength, + @Cast("bool") boolean descending); + public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo, + int[] dimension, + int dimensionLength, + @Cast("bool") boolean descending); + + public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo, + IntPointer dimension, + int dimensionLength, + @Cast("bool") boolean descending); + public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo, + IntBuffer dimension, + int dimensionLength, + @Cast("bool") boolean descending); + public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo, + int[] dimension, + int dimensionLength, + @Cast("bool") boolean descending); + // special sort impl for sorting out COO indices and values public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, Pointer values, @Cast("Nd4jLong") long length, int rank); @@ -9617,8 +9705,8 @@ public static final int PREALLOC_SIZE = 33554432; // #define BROADCAST_BOOL(NAME) nd4j::BroadcastBoolOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME) -public static final int ALL_INTS =INT64; -public static final int ALL_FLOATS =DOUBLE; +public static final int ALL_INTS =UINT64; +public static final int ALL_FLOATS =BFLOAT16; // #endif //TESTS_CPU_TYPE_BOILERPLATE_H diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index c672583f5..b069b2aef 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -7641,12 +7641,13 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(scalarRank2, scalarRank2.dup()); } - @Ignore // https://github.com/deeplearning4j/deeplearning4j/issues/7632 + //@Ignore // https://github.com/deeplearning4j/deeplearning4j/issues/7632 @Test public void testGetWhereINDArray() { INDArray input = Nd4j.create(new double[] { 1, -3, 4, 8, -2, 5 }); + INDArray comp = Nd4j.create(new double[]{2, -3, 1, 1, -2, 1 }); INDArray expected = Nd4j.create(new double[] { 4, 8, 5 }); - INDArray actual = input.getWhere(input, Conditions.greaterThan(1)); + INDArray actual = input.getWhere(comp, Conditions.greaterThan(1)); assertEquals(expected, actual); } @@ -7655,7 +7656,7 @@ public class Nd4jTestsC extends BaseNd4jTest { public void testGetWhereNumber() { INDArray input = Nd4j.create(new double[] { 1, -3, 4, 8, -2, 5 }); INDArray expected = Nd4j.create(new double[] { 8, 5 }); - INDArray actual = input.getWhere(4, Conditions.greaterThan(6)); + INDArray actual = input.getWhere(4, Conditions.greaterThan(1)); assertEquals(expected, actual); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java index e9487fd15..12de3095e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java @@ -31,7 +31,6 @@ import org.nd4j.linalg.api.memory.enums.LearningPolicy; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import sun.awt.image.DataBufferNative; import java.nio.ByteBuffer; import java.nio.ByteOrder;