fake quant dtype validation fix (#60)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-11-19 12:53:52 +03:00 committed by GitHub
parent 8f96f71f2b
commit bbd59a3537
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 46 additions and 1 deletions

View File

@ -46,6 +46,8 @@ namespace nd4j {
max = &m2;
}
auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(x->dataType() == output->dataType(), 0, "fake_quant_with_min_max_vars: input and output data types must be the same");
int numBits = 8;
if (block.getIArguments() && block.getIArguments()->size())
numBits = INT_ARG(0);

View File

@ -40,8 +40,10 @@ namespace nd4j {
REQUIRE_TRUE(depth == max->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Max length should be"
"%lld, but %lld occurs.", depth, max->lengthOf());
auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(x->dataType() == output->dataType(), 0, "fake_quant_with_min_max_vars_per_channel: input and output data types must be the same");
int numBits = 8;
if (block.getIArguments() && block.getIArguments()->size())
numBits = INT_ARG(0);

View File

@ -25,6 +25,8 @@
#include <exceptions/graph_exception.h>
#include <exceptions/unresolved_input_exception.h>
#include <ops/declarable/OpRegistrator.h>
#include <exceptions/datatype_exception.h>
#include <helpers/StringUtils.h>
namespace nd4j {
namespace ops {
@ -227,6 +229,15 @@ namespace nd4j {
nd4j_printf("Expected vs provided shapes mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), pair.second);
throw std::runtime_error("Expected vs provided shapes mismatch");
}
/*
* FIXME: we want to uncomment this eventually, and check data types equality
//checking out data type equality
if (ArrayOptions::dataType(out) != ArrayOptions::dataType(shape)) {
std::string msg = "Provided array [" + StringUtils::valueToString<int>(pair.second) + "] has unexpected data type";
throw nd4j::datatype_exception::build(msg, ArrayOptions::dataType(out), ArrayOptions::dataType(shape));
}
*/
}
} else {
auto fout = ctx.fastpath_out();
@ -237,6 +248,7 @@ namespace nd4j {
ctx.setOutputArray(idx, outArr, true);
} else {
auto array = fout[idx];
// checking out shape equality
if (!shape::equalsSoft(out, array->shapeInfo()) || shape::isEmpty(out) != array->isEmpty()) {
auto eShape = ShapeUtils::shapeAsString(out);
auto aShape = ShapeUtils::shapeAsString(array->shapeInfo());
@ -247,6 +259,15 @@ namespace nd4j {
nd4j_printf("Expected vs provided shape mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), idx);
throw std::runtime_error("Expected vs provided shape mismatch");
}
/*
* FIXME: we want to uncomment this eventually, and check data types equality
//checking out data type equality
if (ArrayOptions::dataType(out) != array->dataType()) {
std::string msg = "Provided array [" + StringUtils::valueToString<int>(idx) + "] has unexpected data type";
throw nd4j::datatype_exception::build(msg, ArrayOptions::dataType(out), array->dataType());
}
*/
}
}
}

View File

@ -426,6 +426,26 @@ TEST_F(JavaInteropTests, Test_FastPath_Validation_2) {
ASSERT_NE(Status::OK(), status);
}
TEST_F(JavaInteropTests, Test_FastPath_Validation_3) {
auto x = NDArrayFactory::create<float>('c', {3, 5}, { 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f,
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f});
auto min = NDArrayFactory::create<float>({ -0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f});
auto max = NDArrayFactory::create<float>({ 0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
auto z = NDArrayFactory::create<double>('c', {3, 5});
Context ctx(1);
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
ctx.setInputArray(1, min.buffer(), min.shapeInfo(), min.specialBuffer(), min.specialShapeInfo());
ctx.setInputArray(2, max.buffer(), max.shapeInfo(), max.specialBuffer(), max.specialShapeInfo());
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
ASSERT_ANY_THROW(op.execute(&ctx));
}
TEST_F(JavaInteropTests, Test_empty_cast_1) {
auto x = NDArrayFactory::create<bool>('c', {1, 0, 2});
auto z = NDArrayFactory::create<Nd4jLong>('c', {1, 0, 2});