diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp index e19c030c4..6d24827e5 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp @@ -46,6 +46,8 @@ namespace nd4j { max = &m2; } auto output = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(x->dataType() == output->dataType(), 0, "fake_quant_with_min_max_vars: input and output data types must be the same"); + int numBits = 8; if (block.getIArguments() && block.getIArguments()->size()) numBits = INT_ARG(0); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp index e5873d9dd..5874d2f81 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp @@ -40,8 +40,10 @@ namespace nd4j { REQUIRE_TRUE(depth == max->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Max length should be" "%lld, but %lld occurs.", depth, max->lengthOf()); - auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(x->dataType() == output->dataType(), 0, "fake_quant_with_min_max_vars_per_channel: input and output data types must be the same"); + int numBits = 8; if (block.getIArguments() && block.getIArguments()->size()) numBits = INT_ARG(0); diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index 5ee19b007..8c65ac25e 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -25,6 +25,8 @@ #include #include #include +#include +#include namespace nd4j { namespace ops { @@ -227,6 +229,15 @@ namespace nd4j { nd4j_printf("Expected vs provided shapes mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), pair.second); throw std::runtime_error("Expected vs provided shapes mismatch"); } + + /* + * FIXME: we want to uncomment this eventually, and check data types equality + //checking out data type equality + if (ArrayOptions::dataType(out) != ArrayOptions::dataType(shape)) { + std::string msg = "Provided array [" + StringUtils::valueToString(pair.second) + "] has unexpected data type"; + throw nd4j::datatype_exception::build(msg, ArrayOptions::dataType(out), ArrayOptions::dataType(shape)); + } + */ } } else { auto fout = ctx.fastpath_out(); @@ -237,6 +248,7 @@ namespace nd4j { ctx.setOutputArray(idx, outArr, true); } else { auto array = fout[idx]; + // checking out shape equality if (!shape::equalsSoft(out, array->shapeInfo()) || shape::isEmpty(out) != array->isEmpty()) { auto eShape = ShapeUtils::shapeAsString(out); auto aShape = ShapeUtils::shapeAsString(array->shapeInfo()); @@ -247,6 +259,15 @@ namespace nd4j { nd4j_printf("Expected vs provided shape mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), idx); throw std::runtime_error("Expected vs provided shape mismatch"); } + + /* + * FIXME: we want to uncomment this eventually, and check data types equality + //checking out data type equality + if (ArrayOptions::dataType(out) != array->dataType()) { + std::string msg = "Provided array [" + StringUtils::valueToString(idx) + "] has unexpected data type"; + throw nd4j::datatype_exception::build(msg, ArrayOptions::dataType(out), array->dataType()); + } + */ } } } diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index aa75ea1ab..a6ca56fd4 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -426,6 +426,26 @@ TEST_F(JavaInteropTests, Test_FastPath_Validation_2) { ASSERT_NE(Status::OK(), status); } +TEST_F(JavaInteropTests, Test_FastPath_Validation_3) { + auto x = NDArrayFactory::create('c', {3, 5}, { 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, + 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, + 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + + auto min = NDArrayFactory::create({ -0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + auto max = NDArrayFactory::create({ 0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + auto z = NDArrayFactory::create('c', {3, 5}); + + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); + ctx.setInputArray(1, min.buffer(), min.shapeInfo(), min.specialBuffer(), min.specialShapeInfo()); + ctx.setInputArray(2, max.buffer(), max.shapeInfo(), max.specialBuffer(), max.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + nd4j::ops::fake_quant_with_min_max_vars_per_channel op; + ASSERT_ANY_THROW(op.execute(&ctx)); +} + TEST_F(JavaInteropTests, Test_empty_cast_1) { auto x = NDArrayFactory::create('c', {1, 0, 2}); auto z = NDArrayFactory::create('c', {1, 0, 2});