[WIP] size etc (#155)
* one test for size Signed-off-by: raver119 <raver119@gmail.com> * - few tests for size op - size/rank/size_at ops now use p instead of assign Signed-off-by: raver119 <raver119@gmail.com>master
parent
e22a2c93ff
commit
729dc5e879
|
@ -87,6 +87,10 @@ namespace nd4j {
|
||||||
std::map<int, std::vector<nd4j::DataType>> _outputTypes;
|
std::map<int, std::vector<nd4j::DataType>> _outputTypes;
|
||||||
std::map<int, std::vector<nd4j::DataType>> _inputTypes;
|
std::map<int, std::vector<nd4j::DataType>> _inputTypes;
|
||||||
|
|
||||||
|
|
||||||
|
// field for ops that allow data type override at runtime
|
||||||
|
bool _dtypeOverride = false;
|
||||||
|
|
||||||
bool checkDataTypesMatch(nd4j::DataType needle, std::vector<nd4j::DataType> &haystack) const;
|
bool checkDataTypesMatch(nd4j::DataType needle, std::vector<nd4j::DataType> &haystack) const;
|
||||||
public:
|
public:
|
||||||
// default constructor
|
// default constructor
|
||||||
|
@ -164,6 +168,7 @@ namespace nd4j {
|
||||||
OpDescriptor* setAllowedOutputTypes(int index, nd4j::DataType dtype);
|
OpDescriptor* setAllowedOutputTypes(int index, nd4j::DataType dtype);
|
||||||
OpDescriptor* setAllowedInputTypes(nd4j::DataType dtype);
|
OpDescriptor* setAllowedInputTypes(nd4j::DataType dtype);
|
||||||
OpDescriptor* setAllowedOutputTypes(nd4j::DataType dtype);
|
OpDescriptor* setAllowedOutputTypes(nd4j::DataType dtype);
|
||||||
|
OpDescriptor* allowOverride(bool reallyAllow);
|
||||||
OpDescriptor* setSameMode(bool reallySame);
|
OpDescriptor* setSameMode(bool reallySame);
|
||||||
OpDescriptor* setInputType(int idx, nd4j::DataType dtype);
|
OpDescriptor* setInputType(int idx, nd4j::DataType dtype);
|
||||||
OpDescriptor* setOutputType(int idx, nd4j::DataType dtype);
|
OpDescriptor* setOutputType(int idx, nd4j::DataType dtype);
|
||||||
|
|
|
@ -31,7 +31,8 @@ namespace nd4j {
|
||||||
|
|
||||||
REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar");
|
REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar");
|
||||||
|
|
||||||
output->assign(input->rankOf());
|
output->p(0, input->rankOf());
|
||||||
|
output->syncToDevice();
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -43,7 +44,8 @@ namespace nd4j {
|
||||||
DECLARE_TYPES(rank) {
|
DECLARE_TYPES(rank) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||||
->setAllowedOutputTypes({ALL_INTS});
|
->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS})
|
||||||
|
->allowOverride(true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,8 @@ namespace nd4j {
|
||||||
|
|
||||||
REQUIRE_TRUE(output->isScalar(), 0, "Size output should be scalar");
|
REQUIRE_TRUE(output->isScalar(), 0, "Size output should be scalar");
|
||||||
|
|
||||||
output->assign(input->lengthOf());
|
output->p(0, input->lengthOf());
|
||||||
|
output->syncToDevice();
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -42,7 +43,8 @@ namespace nd4j {
|
||||||
DECLARE_TYPES(size) {
|
DECLARE_TYPES(size) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||||
->setAllowedOutputTypes({ALL_INTS});
|
->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS})
|
||||||
|
->allowOverride(true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,7 +35,8 @@ namespace nd4j {
|
||||||
|
|
||||||
REQUIRE_TRUE(dim < input->rankOf(), 0, "Size_At: Dim can't be higher then input rank")
|
REQUIRE_TRUE(dim < input->rankOf(), 0, "Size_At: Dim can't be higher then input rank")
|
||||||
|
|
||||||
output->assign(input->sizeAt(dim));
|
output->p(0, input->sizeAt(dim));
|
||||||
|
output->syncToDevice();
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -47,7 +48,8 @@ namespace nd4j {
|
||||||
DECLARE_TYPES(size_at) {
|
DECLARE_TYPES(size_at) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||||
->setAllowedOutputTypes(DataType::INT64);
|
->setAllowedOutputTypes(DataType::INT64)
|
||||||
|
->allowOverride(true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -166,6 +166,11 @@ namespace nd4j {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OpDescriptor* OpDescriptor::allowOverride(bool allowOverride) {
|
||||||
|
_dtypeOverride = allowOverride;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
OpDescriptor* OpDescriptor::setAllowedInputTypes(const nd4j::DataType dtype) {
|
OpDescriptor* OpDescriptor::setAllowedInputTypes(const nd4j::DataType dtype) {
|
||||||
_allowedIns.clear();
|
_allowedIns.clear();
|
||||||
_allowedIns.emplace_back(dtype);
|
_allowedIns.emplace_back(dtype);
|
||||||
|
|
|
@ -55,3 +55,15 @@ TEST_F(DeclarableOpsTests16, test_scatter_update_119) {
|
||||||
|
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests16, test_size_dtype_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1});
|
||||||
|
auto z = NDArrayFactory::create<float>(0.0f);
|
||||||
|
auto e = NDArrayFactory::create<float>(3.0f);
|
||||||
|
|
||||||
|
nd4j::ops::size op;
|
||||||
|
auto status = op.execute({&x}, {&z}, {}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
|
@ -1189,6 +1189,22 @@ TEST_F(JavaInteropTests, test_ismax_view) {
|
||||||
delete t;
|
delete t;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(JavaInteropTests, test_size_dtype_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1});
|
||||||
|
auto z = NDArrayFactory::create<float>(0.0f);
|
||||||
|
auto e = NDArrayFactory::create<float>(3.0f);
|
||||||
|
|
||||||
|
Context ctx(1);
|
||||||
|
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
|
||||||
|
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
|
||||||
|
|
||||||
|
nd4j::ops::size op;
|
||||||
|
auto status = op.execute(&ctx);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
|
TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
|
||||||
auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");
|
auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");
|
||||||
|
|
|
@ -9272,6 +9272,7 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
public native OpDescriptor setAllowedOutputTypes(int index, @Cast("nd4j::DataType") int dtype);
|
public native OpDescriptor setAllowedOutputTypes(int index, @Cast("nd4j::DataType") int dtype);
|
||||||
public native OpDescriptor setAllowedInputTypes(@Cast("nd4j::DataType") int dtype);
|
public native OpDescriptor setAllowedInputTypes(@Cast("nd4j::DataType") int dtype);
|
||||||
public native OpDescriptor setAllowedOutputTypes(@Cast("nd4j::DataType") int dtype);
|
public native OpDescriptor setAllowedOutputTypes(@Cast("nd4j::DataType") int dtype);
|
||||||
|
public native OpDescriptor allowOverride(@Cast("bool") boolean reallyAllow);
|
||||||
public native OpDescriptor setSameMode(@Cast("bool") boolean reallySame);
|
public native OpDescriptor setSameMode(@Cast("bool") boolean reallySame);
|
||||||
public native OpDescriptor setInputType(int idx, @Cast("nd4j::DataType") int dtype);
|
public native OpDescriptor setInputType(int idx, @Cast("nd4j::DataType") int dtype);
|
||||||
public native OpDescriptor setOutputType(int idx, @Cast("nd4j::DataType") int dtype);
|
public native OpDescriptor setOutputType(int idx, @Cast("nd4j::DataType") int dtype);
|
||||||
|
|
|
@ -11557,6 +11557,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
public native OpDescriptor setAllowedOutputTypes(int index, @Cast("nd4j::DataType") int dtype);
|
public native OpDescriptor setAllowedOutputTypes(int index, @Cast("nd4j::DataType") int dtype);
|
||||||
public native OpDescriptor setAllowedInputTypes(@Cast("nd4j::DataType") int dtype);
|
public native OpDescriptor setAllowedInputTypes(@Cast("nd4j::DataType") int dtype);
|
||||||
public native OpDescriptor setAllowedOutputTypes(@Cast("nd4j::DataType") int dtype);
|
public native OpDescriptor setAllowedOutputTypes(@Cast("nd4j::DataType") int dtype);
|
||||||
|
public native OpDescriptor allowOverride(@Cast("bool") boolean reallyAllow);
|
||||||
public native OpDescriptor setSameMode(@Cast("bool") boolean reallySame);
|
public native OpDescriptor setSameMode(@Cast("bool") boolean reallySame);
|
||||||
public native OpDescriptor setInputType(int idx, @Cast("nd4j::DataType") int dtype);
|
public native OpDescriptor setInputType(int idx, @Cast("nd4j::DataType") int dtype);
|
||||||
public native OpDescriptor setOutputType(int idx, @Cast("nd4j::DataType") int dtype);
|
public native OpDescriptor setOutputType(int idx, @Cast("nd4j::DataType") int dtype);
|
||||||
|
|
|
@ -44,6 +44,7 @@ import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
@ -641,4 +642,34 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
|
|
||||||
assertEquals(result1, result2);
|
assertEquals(result1, result2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSizeTypes(){
|
||||||
|
List<DataType> failed = new ArrayList<>();
|
||||||
|
for(DataType dt : new DataType[]{DataType.LONG, DataType.INT, DataType.SHORT, DataType.BYTE,
|
||||||
|
DataType.UINT64, DataType.UINT32, DataType.UINT16, DataType.UBYTE,
|
||||||
|
DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.BFLOAT16}) {
|
||||||
|
|
||||||
|
INDArray in = Nd4j.create(DataType.FLOAT, 100);
|
||||||
|
INDArray out = Nd4j.scalar(dt, 0);
|
||||||
|
INDArray e = Nd4j.scalar(dt, 100);
|
||||||
|
|
||||||
|
DynamicCustomOp op = DynamicCustomOp.builder("size")
|
||||||
|
.addInputs(in)
|
||||||
|
.addOutputs(out)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
try {
|
||||||
|
Nd4j.exec(op);
|
||||||
|
|
||||||
|
assertEquals(e, out);
|
||||||
|
} catch (Throwable t){
|
||||||
|
failed.add(dt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if(!failed.isEmpty()){
|
||||||
|
fail("Failed datatypes: " + failed.toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue