[WIP] size etc (#155)

* one test for size

Signed-off-by: raver119 <raver119@gmail.com>

* - few tests for size op
- size/rank/size_at ops now use p instead of assign

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-23 12:31:12 +03:00 committed by GitHub
parent e22a2c93ff
commit 729dc5e879
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 83 additions and 6 deletions

View File

@ -87,6 +87,10 @@ namespace nd4j {
std::map<int, std::vector<nd4j::DataType>> _outputTypes; std::map<int, std::vector<nd4j::DataType>> _outputTypes;
std::map<int, std::vector<nd4j::DataType>> _inputTypes; std::map<int, std::vector<nd4j::DataType>> _inputTypes;
// field for ops that allow data type override at runtime
bool _dtypeOverride = false;
bool checkDataTypesMatch(nd4j::DataType needle, std::vector<nd4j::DataType> &haystack) const; bool checkDataTypesMatch(nd4j::DataType needle, std::vector<nd4j::DataType> &haystack) const;
public: public:
// default constructor // default constructor
@ -164,6 +168,7 @@ namespace nd4j {
OpDescriptor* setAllowedOutputTypes(int index, nd4j::DataType dtype); OpDescriptor* setAllowedOutputTypes(int index, nd4j::DataType dtype);
OpDescriptor* setAllowedInputTypes(nd4j::DataType dtype); OpDescriptor* setAllowedInputTypes(nd4j::DataType dtype);
OpDescriptor* setAllowedOutputTypes(nd4j::DataType dtype); OpDescriptor* setAllowedOutputTypes(nd4j::DataType dtype);
OpDescriptor* allowOverride(bool reallyAllow);
OpDescriptor* setSameMode(bool reallySame); OpDescriptor* setSameMode(bool reallySame);
OpDescriptor* setInputType(int idx, nd4j::DataType dtype); OpDescriptor* setInputType(int idx, nd4j::DataType dtype);
OpDescriptor* setOutputType(int idx, nd4j::DataType dtype); OpDescriptor* setOutputType(int idx, nd4j::DataType dtype);

View File

@ -31,7 +31,8 @@ namespace nd4j {
REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar"); REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar");
output->assign(input->rankOf()); output->p(0, input->rankOf());
output->syncToDevice();
return Status::OK(); return Status::OK();
} }
@ -43,7 +44,8 @@ namespace nd4j {
DECLARE_TYPES(rank) { DECLARE_TYPES(rank) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_INTS}); ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS})
->allowOverride(true);
} }
} }
} }

View File

@ -31,7 +31,8 @@ namespace nd4j {
REQUIRE_TRUE(output->isScalar(), 0, "Size output should be scalar"); REQUIRE_TRUE(output->isScalar(), 0, "Size output should be scalar");
output->assign(input->lengthOf()); output->p(0, input->lengthOf());
output->syncToDevice();
return Status::OK(); return Status::OK();
} }
@ -42,7 +43,8 @@ namespace nd4j {
DECLARE_TYPES(size) { DECLARE_TYPES(size) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_INTS}); ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS})
->allowOverride(true);
} }
} }
} }

View File

@ -35,7 +35,8 @@ namespace nd4j {
REQUIRE_TRUE(dim < input->rankOf(), 0, "Size_At: Dim can't be higher then input rank") REQUIRE_TRUE(dim < input->rankOf(), 0, "Size_At: Dim can't be higher then input rank")
output->assign(input->sizeAt(dim)); output->p(0, input->sizeAt(dim));
output->syncToDevice();
return Status::OK(); return Status::OK();
} }
@ -47,7 +48,8 @@ namespace nd4j {
DECLARE_TYPES(size_at) { DECLARE_TYPES(size_at) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes(DataType::INT64); ->setAllowedOutputTypes(DataType::INT64)
->allowOverride(true);
} }
} }
} }

View File

@ -166,6 +166,11 @@ namespace nd4j {
return this; return this;
} }
OpDescriptor* OpDescriptor::allowOverride(bool allowOverride) {
_dtypeOverride = allowOverride;
return this;
}
OpDescriptor* OpDescriptor::setAllowedInputTypes(const nd4j::DataType dtype) { OpDescriptor* OpDescriptor::setAllowedInputTypes(const nd4j::DataType dtype) {
_allowedIns.clear(); _allowedIns.clear();
_allowedIns.emplace_back(dtype); _allowedIns.emplace_back(dtype);

View File

@ -54,4 +54,16 @@ TEST_F(DeclarableOpsTests16, test_scatter_update_119) {
ASSERT_EQ(e, *z); ASSERT_EQ(e, *z);
delete result; delete result;
}
TEST_F(DeclarableOpsTests16, test_size_dtype_1) {
auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1});
auto z = NDArrayFactory::create<float>(0.0f);
auto e = NDArrayFactory::create<float>(3.0f);
nd4j::ops::size op;
auto status = op.execute({&x}, {&z}, {}, {}, {});
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z);
} }

View File

@ -1189,6 +1189,22 @@ TEST_F(JavaInteropTests, test_ismax_view) {
delete t; delete t;
} }
TEST_F(JavaInteropTests, test_size_dtype_1) {
auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1});
auto z = NDArrayFactory::create<float>(0.0f);
auto e = NDArrayFactory::create<float>(3.0f);
Context ctx(1);
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
nd4j::ops::size op;
auto status = op.execute(&ctx);
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z);
}
/* /*
TEST_F(JavaInteropTests, Test_Results_Conversion_1) { TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb"); auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");

View File

@ -9272,6 +9272,7 @@ public static final int PREALLOC_SIZE = 33554432;
public native OpDescriptor setAllowedOutputTypes(int index, @Cast("nd4j::DataType") int dtype); public native OpDescriptor setAllowedOutputTypes(int index, @Cast("nd4j::DataType") int dtype);
public native OpDescriptor setAllowedInputTypes(@Cast("nd4j::DataType") int dtype); public native OpDescriptor setAllowedInputTypes(@Cast("nd4j::DataType") int dtype);
public native OpDescriptor setAllowedOutputTypes(@Cast("nd4j::DataType") int dtype); public native OpDescriptor setAllowedOutputTypes(@Cast("nd4j::DataType") int dtype);
public native OpDescriptor allowOverride(@Cast("bool") boolean reallyAllow);
public native OpDescriptor setSameMode(@Cast("bool") boolean reallySame); public native OpDescriptor setSameMode(@Cast("bool") boolean reallySame);
public native OpDescriptor setInputType(int idx, @Cast("nd4j::DataType") int dtype); public native OpDescriptor setInputType(int idx, @Cast("nd4j::DataType") int dtype);
public native OpDescriptor setOutputType(int idx, @Cast("nd4j::DataType") int dtype); public native OpDescriptor setOutputType(int idx, @Cast("nd4j::DataType") int dtype);

View File

@ -11557,6 +11557,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native OpDescriptor setAllowedOutputTypes(int index, @Cast("nd4j::DataType") int dtype); public native OpDescriptor setAllowedOutputTypes(int index, @Cast("nd4j::DataType") int dtype);
public native OpDescriptor setAllowedInputTypes(@Cast("nd4j::DataType") int dtype); public native OpDescriptor setAllowedInputTypes(@Cast("nd4j::DataType") int dtype);
public native OpDescriptor setAllowedOutputTypes(@Cast("nd4j::DataType") int dtype); public native OpDescriptor setAllowedOutputTypes(@Cast("nd4j::DataType") int dtype);
public native OpDescriptor allowOverride(@Cast("bool") boolean reallyAllow);
public native OpDescriptor setSameMode(@Cast("bool") boolean reallySame); public native OpDescriptor setSameMode(@Cast("bool") boolean reallySame);
public native OpDescriptor setInputType(int idx, @Cast("nd4j::DataType") int dtype); public native OpDescriptor setInputType(int idx, @Cast("nd4j::DataType") int dtype);
public native OpDescriptor setOutputType(int idx, @Cast("nd4j::DataType") int dtype); public native OpDescriptor setOutputType(int idx, @Cast("nd4j::DataType") int dtype);

View File

@ -44,6 +44,7 @@ import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.nativeblas.NativeOpsHolder;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@ -641,4 +642,34 @@ public class CustomOpsTests extends BaseNd4jTest {
assertEquals(result1, result2); assertEquals(result1, result2);
} }
@Test
public void testSizeTypes(){
List<DataType> failed = new ArrayList<>();
for(DataType dt : new DataType[]{DataType.LONG, DataType.INT, DataType.SHORT, DataType.BYTE,
DataType.UINT64, DataType.UINT32, DataType.UINT16, DataType.UBYTE,
DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.BFLOAT16}) {
INDArray in = Nd4j.create(DataType.FLOAT, 100);
INDArray out = Nd4j.scalar(dt, 0);
INDArray e = Nd4j.scalar(dt, 100);
DynamicCustomOp op = DynamicCustomOp.builder("size")
.addInputs(in)
.addOutputs(out)
.build();
try {
Nd4j.exec(op);
assertEquals(e, out);
} catch (Throwable t){
failed.add(dt);
}
}
if(!failed.isEmpty()){
fail("Failed datatypes: " + failed.toString());
}
}
} }