[WIP] size etc (#155)
* one test for size Signed-off-by: raver119 <raver119@gmail.com> * - few tests for size op - size/rank/size_at ops now use p instead of assign Signed-off-by: raver119 <raver119@gmail.com>master
parent
e22a2c93ff
commit
729dc5e879
|
@ -87,6 +87,10 @@ namespace nd4j {
|
|||
std::map<int, std::vector<nd4j::DataType>> _outputTypes;
|
||||
std::map<int, std::vector<nd4j::DataType>> _inputTypes;
|
||||
|
||||
|
||||
// field for ops that allow data type override at runtime
|
||||
bool _dtypeOverride = false;
|
||||
|
||||
bool checkDataTypesMatch(nd4j::DataType needle, std::vector<nd4j::DataType> &haystack) const;
|
||||
public:
|
||||
// default constructor
|
||||
|
@ -164,6 +168,7 @@ namespace nd4j {
|
|||
OpDescriptor* setAllowedOutputTypes(int index, nd4j::DataType dtype);
|
||||
OpDescriptor* setAllowedInputTypes(nd4j::DataType dtype);
|
||||
OpDescriptor* setAllowedOutputTypes(nd4j::DataType dtype);
|
||||
OpDescriptor* allowOverride(bool reallyAllow);
|
||||
OpDescriptor* setSameMode(bool reallySame);
|
||||
OpDescriptor* setInputType(int idx, nd4j::DataType dtype);
|
||||
OpDescriptor* setOutputType(int idx, nd4j::DataType dtype);
|
||||
|
|
|
@ -31,7 +31,8 @@ namespace nd4j {
|
|||
|
||||
REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar");
|
||||
|
||||
output->assign(input->rankOf());
|
||||
output->p(0, input->rankOf());
|
||||
output->syncToDevice();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -43,7 +44,8 @@ namespace nd4j {
|
|||
DECLARE_TYPES(rank) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||
->setAllowedOutputTypes({ALL_INTS});
|
||||
->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS})
|
||||
->allowOverride(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,7 +31,8 @@ namespace nd4j {
|
|||
|
||||
REQUIRE_TRUE(output->isScalar(), 0, "Size output should be scalar");
|
||||
|
||||
output->assign(input->lengthOf());
|
||||
output->p(0, input->lengthOf());
|
||||
output->syncToDevice();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -42,7 +43,8 @@ namespace nd4j {
|
|||
DECLARE_TYPES(size) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||
->setAllowedOutputTypes({ALL_INTS});
|
||||
->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS})
|
||||
->allowOverride(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,7 +35,8 @@ namespace nd4j {
|
|||
|
||||
REQUIRE_TRUE(dim < input->rankOf(), 0, "Size_At: Dim can't be higher then input rank")
|
||||
|
||||
output->assign(input->sizeAt(dim));
|
||||
output->p(0, input->sizeAt(dim));
|
||||
output->syncToDevice();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -47,7 +48,8 @@ namespace nd4j {
|
|||
DECLARE_TYPES(size_at) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||
->setAllowedOutputTypes(DataType::INT64);
|
||||
->setAllowedOutputTypes(DataType::INT64)
|
||||
->allowOverride(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -166,6 +166,11 @@ namespace nd4j {
|
|||
return this;
|
||||
}
|
||||
|
||||
OpDescriptor* OpDescriptor::allowOverride(bool allowOverride) {
|
||||
_dtypeOverride = allowOverride;
|
||||
return this;
|
||||
}
|
||||
|
||||
OpDescriptor* OpDescriptor::setAllowedInputTypes(const nd4j::DataType dtype) {
|
||||
_allowedIns.clear();
|
||||
_allowedIns.emplace_back(dtype);
|
||||
|
|
|
@ -54,4 +54,16 @@ TEST_F(DeclarableOpsTests16, test_scatter_update_119) {
|
|||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_size_dtype_1) {
|
||||
auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1});
|
||||
auto z = NDArrayFactory::create<float>(0.0f);
|
||||
auto e = NDArrayFactory::create<float>(3.0f);
|
||||
|
||||
nd4j::ops::size op;
|
||||
auto status = op.execute({&x}, {&z}, {}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
|
@ -1189,6 +1189,22 @@ TEST_F(JavaInteropTests, test_ismax_view) {
|
|||
delete t;
|
||||
}
|
||||
|
||||
TEST_F(JavaInteropTests, test_size_dtype_1) {
|
||||
auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1});
|
||||
auto z = NDArrayFactory::create<float>(0.0f);
|
||||
auto e = NDArrayFactory::create<float>(3.0f);
|
||||
|
||||
Context ctx(1);
|
||||
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
|
||||
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
|
||||
|
||||
nd4j::ops::size op;
|
||||
auto status = op.execute(&ctx);
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
||||
/*
|
||||
TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
|
||||
auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");
|
||||
|
|
|
@ -9272,6 +9272,7 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
public native OpDescriptor setAllowedOutputTypes(int index, @Cast("nd4j::DataType") int dtype);
|
||||
public native OpDescriptor setAllowedInputTypes(@Cast("nd4j::DataType") int dtype);
|
||||
public native OpDescriptor setAllowedOutputTypes(@Cast("nd4j::DataType") int dtype);
|
||||
public native OpDescriptor allowOverride(@Cast("bool") boolean reallyAllow);
|
||||
public native OpDescriptor setSameMode(@Cast("bool") boolean reallySame);
|
||||
public native OpDescriptor setInputType(int idx, @Cast("nd4j::DataType") int dtype);
|
||||
public native OpDescriptor setOutputType(int idx, @Cast("nd4j::DataType") int dtype);
|
||||
|
|
|
@ -11557,6 +11557,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
public native OpDescriptor setAllowedOutputTypes(int index, @Cast("nd4j::DataType") int dtype);
|
||||
public native OpDescriptor setAllowedInputTypes(@Cast("nd4j::DataType") int dtype);
|
||||
public native OpDescriptor setAllowedOutputTypes(@Cast("nd4j::DataType") int dtype);
|
||||
public native OpDescriptor allowOverride(@Cast("bool") boolean reallyAllow);
|
||||
public native OpDescriptor setSameMode(@Cast("bool") boolean reallySame);
|
||||
public native OpDescriptor setInputType(int idx, @Cast("nd4j::DataType") int dtype);
|
||||
public native OpDescriptor setOutputType(int idx, @Cast("nd4j::DataType") int dtype);
|
||||
|
|
|
@ -44,6 +44,7 @@ import org.nd4j.linalg.factory.Nd4jBackend;
|
|||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
@ -641,4 +642,34 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
|
||||
assertEquals(result1, result2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSizeTypes(){
|
||||
List<DataType> failed = new ArrayList<>();
|
||||
for(DataType dt : new DataType[]{DataType.LONG, DataType.INT, DataType.SHORT, DataType.BYTE,
|
||||
DataType.UINT64, DataType.UINT32, DataType.UINT16, DataType.UBYTE,
|
||||
DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.BFLOAT16}) {
|
||||
|
||||
INDArray in = Nd4j.create(DataType.FLOAT, 100);
|
||||
INDArray out = Nd4j.scalar(dt, 0);
|
||||
INDArray e = Nd4j.scalar(dt, 100);
|
||||
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("size")
|
||||
.addInputs(in)
|
||||
.addOutputs(out)
|
||||
.build();
|
||||
|
||||
try {
|
||||
Nd4j.exec(op);
|
||||
|
||||
assertEquals(e, out);
|
||||
} catch (Throwable t){
|
||||
failed.add(dt);
|
||||
}
|
||||
}
|
||||
|
||||
if(!failed.isEmpty()){
|
||||
fail("Failed datatypes: " + failed.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue