Small fixes (#43)

* minor cmake changes to make macos happy

* space_to_batch/batch_to_space validation fix

* - choose op tweaks
- tests updated to match appleid tweaks

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* - get rid of bad import

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* - choose now uses shape function
- choose test updated
master
raver119 2019-07-03 14:24:50 +03:00 committed by AlexDBlack
parent 8b4b8977a0
commit dde50ee570
14 changed files with 145 additions and 33 deletions

View File

@ -349,7 +349,7 @@ elseif(CPU_BLAS)
endif()
endif()
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(APPLE))
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -Wl,-export-dynamic")
SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic")
endif()

View File

@ -27,7 +27,7 @@
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(choose, -1, 2, false, -1, -1) {
CUSTOM_OP_IMPL(choose, -1, 2, false, -2, -1) {
int mode = INT_ARG(0);
auto result = OUTPUT_VARIABLE(0);
@ -61,6 +61,8 @@ namespace nd4j {
DECLARE_SHAPE_FN(choose) {
Nd4jLong *shape;
int rank;
int mode = INT_ARG(0);
auto numResults = NDArrayFactory::create<Nd4jLong>(0L);
if(block.width() > 1) {
auto first = INPUT_VARIABLE(0);
auto second = INPUT_VARIABLE(1);
@ -72,18 +74,22 @@ namespace nd4j {
shape = second->getShapeInfo();
rank = second->rankOf();
}
helpers::chooseFunctorArray(block.launchContext(), first, second, mode, nullptr, &numResults);
}
else {
auto first = INPUT_VARIABLE(0);
shape = first->getShapeInfo();
rank = first->rankOf();
double scalar = T_ARG(0);
helpers::chooseFunctorScalar(block.launchContext(), first, scalar, mode, nullptr, &numResults);
}
Nd4jLong* newShape;
COPY_SHAPE(shape, newShape);
auto newShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(numResults.e<Nd4jLong>(0), ArrayOptions::dataType(inputShape->at(0)));
auto shapeScalar = ConstantShapeHelper::getInstance()->scalarShapeInfo(nd4j::DataType::INT64);
return SHAPELIST(CONSTANT(newShape), shapeScalar);
return SHAPELIST(newShape, shapeScalar);
}

View File

@ -68,10 +68,10 @@ namespace ops {
auto blocks = INPUT_VARIABLE(1);
auto crops = INPUT_VARIABLE(2);
block_dims = (int) blocks->lengthOf();
block_dims = (int) blocks->sizeAt(0);
REQUIRE_TRUE(blocks->isVector() || blocks->lengthOf() == 1, 0, "BatchToSpace: blocks supposed to be vector or scalar, but got %iD instead", blocks->rankOf());
REQUIRE_TRUE(input->rankOf() >= 1 + blocks->lengthOf() + 1, 0, "BatchToSpace: blocks length + 2 should match input rank at least");
REQUIRE_TRUE(input->rankOf() >= 1 + blocks->lengthOf(), 0, "BatchToSpace: blocks length + 1 should match input rank at least");
REQUIRE_TRUE(crops->rankOf() == 2, 0, "BatchToSpace: padding should have rank of 2, but got %i instead", crops->rankOf());
REQUIRE_TRUE(crops->columns() == 2 && blocks->lengthOf() == crops->rows(), 0, "BatchToSpace: padding should have M rows and 2 columns");
@ -198,7 +198,7 @@ namespace ops {
auto blocks = INPUT_VARIABLE(1);
auto crops = INPUT_VARIABLE(2);
block_dims = (int) blocks->lengthOf();
block_dims = (int) blocks->sizeAt(0);
block_shape = blocks->template asVectorT<int>();
crops_shape = crops->template asVectorT<int>();

View File

@ -75,7 +75,7 @@ namespace ops {
block_shape.resize(block_dims);
padding_shape.resize(M*2);
REQUIRE_TRUE(input->rankOf() >= 1 + M + 1, 0, "SpaceToBatch: blocks length + 2 should match input rank at least");
REQUIRE_TRUE(input->rankOf() >= 1 + M, 0, "SpaceToBatch: blocks length + 1 should match input rank at least");
int e = 0;
for (; e < block_dims; e++)

View File

@ -120,7 +120,7 @@ namespace nd4j {
* @tparam T
*/
#if NOT_EXCLUDED(OP_choose)
DECLARE_CUSTOM_OP(choose, -1, 1, false, -1, -1);
DECLARE_CUSTOM_OP(choose, -1, 1, false, -2, -1);
#endif
/**

View File

@ -46,6 +46,7 @@ namespace helpers {
for (Nd4jLong i = 0; i < arg->lengthOf(); i++) {
T result2 = processElementCondition(mode, arg->e<T>(i), comp->e<T>(0));
if(result2 > 0) {
if (output != nullptr)
output->p(numResults, arg->e<T>(i));
numResults++;
}
@ -56,8 +57,9 @@ namespace helpers {
//for comparison
nd4j::NDArray arg1 = *arg;
for (Nd4jLong i = 0; i < arg->lengthOf(); i++) {
T result2 = processElementCondition(mode, arg->e<T>(i), compScalar.e<T>(0));
T result2 = processElementCondition(mode, arg->e<T>(i), comp->e<T>(i));
if(result2 > 0) {
if (output != nullptr)
output->p(numResults, arg->e<T>(i));
numResults++;
}
@ -72,6 +74,7 @@ namespace helpers {
for (Nd4jLong i = 0; i < arg->lengthOf(); i++) {
T result2 = processElementCondition(mode, arg->e<T>(i), compScalar.e<T>(0));
if(result2 > 0) {
if (output != nullptr)
output->p(numResults, arg->e<T>(i));
numResults++;
}
@ -82,7 +85,6 @@ namespace helpers {
numResult->p(0,numResults);
return output;
}
nd4j::NDArray* processCondition(nd4j::LaunchContext * context, int mode,nd4j::NDArray *arg, nd4j::NDArray *comp, nd4j::NDArray *output, nd4j::NDArray *numResult, nd4j::NDArray& compScalar) {

View File

@ -127,5 +127,23 @@ TEST_F(BooleanOpsTests, Is_numeric_tensor_1) {
nd4j::ops::is_numeric_tensor op;
ASSERT_TRUE(op.evaluate({&x}));
}
TEST_F(BooleanOpsTests, test_where_1) {
auto x = NDArrayFactory::create<double>('c', {6}, { 1, -3, 4, 8, -2, 5 });
auto y = NDArrayFactory::create<double>('c', {6}, { 2, -3, 1, 1, -2, 1 });
auto e = NDArrayFactory::create<double>('c', {3}, { 4, 8, 5 });
nd4j:ops::choose op;
auto result = op.execute({&x, &y}, {}, {3});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
z->printIndexedBuffer("z");
ASSERT_EQ(e, *z);
delete result;
}

View File

@ -113,7 +113,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR) {
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(4,z->lengthOf());
ASSERT_EQ(3, z->lengthOf());
//ASSERT_TRUE(exp.isSameShape(z));
delete result;
@ -137,7 +137,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_LEFT) {
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(4,z->lengthOf());
ASSERT_EQ(3,z->lengthOf());
//ASSERT_TRUE(exp.isSameShape(z));
delete result;
@ -160,7 +160,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_ONLY_SCALAR) {
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(4,z->lengthOf());
ASSERT_EQ(2,z->lengthOf());
//ASSERT_TRUE(exp.isSameShape(z));
delete result;
@ -183,7 +183,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_ONLY_SCALAR_GTE) {
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(4,z->lengthOf());
ASSERT_EQ(3,z->lengthOf());
//ASSERT_TRUE(exp.isSameShape(z));
delete result;

View File

@ -209,7 +209,7 @@ file(GLOB_RECURSE LOOPS_SOURCES false ../../include/loops/*.cpp ../../include/lo
message("CPU BLAS")
add_definitions(-D__CPUBLAS__=true)
if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(MINGW))
if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(MINGW) AND NOT(APPLE))
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -Wl,-export-dynamic")
SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic")
endif()

View File

@ -69,7 +69,6 @@ public class Choose extends DynamicCustomOp {
addInputArgument(inputs);
addIArgument(condition.condtionNum());
addOutputArgument(Nd4j.create(inputs[0].length()),Nd4j.scalar(1.0));
}
/**
@ -106,8 +105,6 @@ public class Choose extends DynamicCustomOp {
if(!tArgs.isEmpty())
addTArgument(Doubles.toArray(tArgs));
addIArgument(condition.condtionNum());
addOutputArgument(Nd4j.create(inputs[0].shape(), inputs[0].ordering()),Nd4j.scalar(DataType.LONG, 1.0));
}
public Choose(String opName, SameDiff sameDiff, SDVariable[] args, boolean inPlace) {

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.indexing;
import lombok.NonNull;
import lombok.val;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex;
@ -191,9 +192,9 @@ public class BooleanIndexing {
* ffor the given conditions
*/
public static INDArray chooseFrom(@NonNull INDArray[] input,@NonNull Condition condition) {
Choose choose = new Choose(input,condition);
Nd4j.getExecutioner().execAndReturn(choose);
int secondOutput = choose.getOutputArgument(1).getInt(0);
val choose = new Choose(input,condition);
val outputs = Nd4j.exec(choose);
int secondOutput = outputs[1].getInt(0);
if(secondOutput < 1) {
return null;
}

View File

@ -2810,6 +2810,44 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps {
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
@Cast("bool") boolean descending);
public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo,
@Cast("bool") boolean descending);
public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo,
@Cast("bool") boolean descending);
public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo,
@Cast("bool") boolean descending);
public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo,
@Cast("bool") boolean descending);
public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo,
@Cast("bool") boolean descending);
public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo,
@Cast("bool") boolean descending);
public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
@ -2835,6 +2873,56 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps {
@Cast("Nd4jLong*") long[] tadOffsets,
@Cast("bool") boolean descending);
public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo,
IntPointer dimension,
int dimensionLength,
@Cast("bool") boolean descending);
public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo,
IntBuffer dimension,
int dimensionLength,
@Cast("bool") boolean descending);
public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo,
int[] dimension,
int dimensionLength,
@Cast("bool") boolean descending);
public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo,
IntPointer dimension,
int dimensionLength,
@Cast("bool") boolean descending);
public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo,
IntBuffer dimension,
int dimensionLength,
@Cast("bool") boolean descending);
public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo,
int[] dimension,
int dimensionLength,
@Cast("bool") boolean descending);
// special sort impl for sorting out COO indices and values
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, Pointer values, @Cast("Nd4jLong") long length, int rank);
@ -9617,8 +9705,8 @@ public static final int PREALLOC_SIZE = 33554432;
// #define BROADCAST_BOOL(NAME) nd4j::BroadcastBoolOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME)
public static final int ALL_INTS =INT64;
public static final int ALL_FLOATS =DOUBLE;
public static final int ALL_INTS =UINT64;
public static final int ALL_FLOATS =BFLOAT16;
// #endif //TESTS_CPU_TYPE_BOILERPLATE_H

View File

@ -7641,12 +7641,13 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(scalarRank2, scalarRank2.dup());
}
@Ignore // https://github.com/deeplearning4j/deeplearning4j/issues/7632
//@Ignore // https://github.com/deeplearning4j/deeplearning4j/issues/7632
@Test
public void testGetWhereINDArray() {
INDArray input = Nd4j.create(new double[] { 1, -3, 4, 8, -2, 5 });
INDArray comp = Nd4j.create(new double[]{2, -3, 1, 1, -2, 1 });
INDArray expected = Nd4j.create(new double[] { 4, 8, 5 });
INDArray actual = input.getWhere(input, Conditions.greaterThan(1));
INDArray actual = input.getWhere(comp, Conditions.greaterThan(1));
assertEquals(expected, actual);
}
@ -7655,7 +7656,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
public void testGetWhereNumber() {
INDArray input = Nd4j.create(new double[] { 1, -3, 4, 8, -2, 5 });
INDArray expected = Nd4j.create(new double[] { 8, 5 });
INDArray actual = input.getWhere(4, Conditions.greaterThan(6));
INDArray actual = input.getWhere(4, Conditions.greaterThan(1));
assertEquals(expected, actual);
}

View File

@ -31,7 +31,6 @@ import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import sun.awt.image.DataBufferNative;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;