[WIP] size etc (#155)

* one test for size

Signed-off-by: raver119 <raver119@gmail.com>

* - few tests for size op
- size/rank/size_at ops now use p instead of assign

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-23 12:31:12 +03:00 committed by GitHub
parent e22a2c93ff
commit 729dc5e879
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 83 additions and 6 deletions

View File

@ -87,6 +87,10 @@ namespace nd4j {
std::map<int, std::vector<nd4j::DataType>> _outputTypes;
std::map<int, std::vector<nd4j::DataType>> _inputTypes;
// field for ops that allow data type override at runtime
bool _dtypeOverride = false;
bool checkDataTypesMatch(nd4j::DataType needle, std::vector<nd4j::DataType> &haystack) const;
public:
// default constructor
@ -164,6 +168,7 @@ namespace nd4j {
OpDescriptor* setAllowedOutputTypes(int index, nd4j::DataType dtype);
OpDescriptor* setAllowedInputTypes(nd4j::DataType dtype);
OpDescriptor* setAllowedOutputTypes(nd4j::DataType dtype);
OpDescriptor* allowOverride(bool reallyAllow);
OpDescriptor* setSameMode(bool reallySame);
OpDescriptor* setInputType(int idx, nd4j::DataType dtype);
OpDescriptor* setOutputType(int idx, nd4j::DataType dtype);

View File

@ -31,7 +31,8 @@ namespace nd4j {
REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar");
output->assign(input->rankOf());
output->p(0, input->rankOf());
output->syncToDevice();
return Status::OK();
}
@ -43,7 +44,8 @@ namespace nd4j {
DECLARE_TYPES(rank) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_INTS});
->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS})
->allowOverride(true);
}
}
}

View File

@ -31,7 +31,8 @@ namespace nd4j {
REQUIRE_TRUE(output->isScalar(), 0, "Size output should be scalar");
output->assign(input->lengthOf());
output->p(0, input->lengthOf());
output->syncToDevice();
return Status::OK();
}
@ -42,7 +43,8 @@ namespace nd4j {
DECLARE_TYPES(size) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_INTS});
->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS})
->allowOverride(true);
}
}
}

View File

@ -35,7 +35,8 @@ namespace nd4j {
REQUIRE_TRUE(dim < input->rankOf(), 0, "Size_At: Dim can't be higher then input rank")
output->assign(input->sizeAt(dim));
output->p(0, input->sizeAt(dim));
output->syncToDevice();
return Status::OK();
}
@ -47,7 +48,8 @@ namespace nd4j {
DECLARE_TYPES(size_at) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes(DataType::INT64);
->setAllowedOutputTypes(DataType::INT64)
->allowOverride(true);
}
}
}

View File

@ -166,6 +166,11 @@ namespace nd4j {
return this;
}
OpDescriptor* OpDescriptor::allowOverride(bool allowOverride) {
_dtypeOverride = allowOverride;
return this;
}
OpDescriptor* OpDescriptor::setAllowedInputTypes(const nd4j::DataType dtype) {
_allowedIns.clear();
_allowedIns.emplace_back(dtype);

View File

@ -55,3 +55,15 @@ TEST_F(DeclarableOpsTests16, test_scatter_update_119) {
delete result;
}
TEST_F(DeclarableOpsTests16, test_size_dtype_1) {
auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1});
auto z = NDArrayFactory::create<float>(0.0f);
auto e = NDArrayFactory::create<float>(3.0f);
nd4j::ops::size op;
auto status = op.execute({&x}, {&z}, {}, {}, {});
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z);
}

View File

@ -1189,6 +1189,22 @@ TEST_F(JavaInteropTests, test_ismax_view) {
delete t;
}
TEST_F(JavaInteropTests, test_size_dtype_1) {
auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1});
auto z = NDArrayFactory::create<float>(0.0f);
auto e = NDArrayFactory::create<float>(3.0f);
Context ctx(1);
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
nd4j::ops::size op;
auto status = op.execute(&ctx);
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z);
}
/*
TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");

View File

@ -9272,6 +9272,7 @@ public static final int PREALLOC_SIZE = 33554432;
public native OpDescriptor setAllowedOutputTypes(int index, @Cast("nd4j::DataType") int dtype);
public native OpDescriptor setAllowedInputTypes(@Cast("nd4j::DataType") int dtype);
public native OpDescriptor setAllowedOutputTypes(@Cast("nd4j::DataType") int dtype);
public native OpDescriptor allowOverride(@Cast("bool") boolean reallyAllow);
public native OpDescriptor setSameMode(@Cast("bool") boolean reallySame);
public native OpDescriptor setInputType(int idx, @Cast("nd4j::DataType") int dtype);
public native OpDescriptor setOutputType(int idx, @Cast("nd4j::DataType") int dtype);

View File

@ -11557,6 +11557,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native OpDescriptor setAllowedOutputTypes(int index, @Cast("nd4j::DataType") int dtype);
public native OpDescriptor setAllowedInputTypes(@Cast("nd4j::DataType") int dtype);
public native OpDescriptor setAllowedOutputTypes(@Cast("nd4j::DataType") int dtype);
public native OpDescriptor allowOverride(@Cast("bool") boolean reallyAllow);
public native OpDescriptor setSameMode(@Cast("bool") boolean reallySame);
public native OpDescriptor setInputType(int idx, @Cast("nd4j::DataType") int dtype);
public native OpDescriptor setOutputType(int idx, @Cast("nd4j::DataType") int dtype);

View File

@ -44,6 +44,7 @@ import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.nativeblas.NativeOpsHolder;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.*;
@ -641,4 +642,34 @@ public class CustomOpsTests extends BaseNd4jTest {
assertEquals(result1, result2);
}
@Test
public void testSizeTypes(){
List<DataType> failed = new ArrayList<>();
for(DataType dt : new DataType[]{DataType.LONG, DataType.INT, DataType.SHORT, DataType.BYTE,
DataType.UINT64, DataType.UINT32, DataType.UINT16, DataType.UBYTE,
DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.BFLOAT16}) {
INDArray in = Nd4j.create(DataType.FLOAT, 100);
INDArray out = Nd4j.scalar(dt, 0);
INDArray e = Nd4j.scalar(dt, 100);
DynamicCustomOp op = DynamicCustomOp.builder("size")
.addInputs(in)
.addOutputs(out)
.build();
try {
Nd4j.exec(op);
assertEquals(e, out);
} catch (Throwable t){
failed.add(dt);
}
}
if(!failed.isEmpty()){
fail("Failed datatypes: " + failed.toString());
}
}
}