parent
8f96f71f2b
commit
bbd59a3537
|
@ -46,6 +46,8 @@ namespace nd4j {
|
|||
max = &m2;
|
||||
}
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
REQUIRE_TRUE(x->dataType() == output->dataType(), 0, "fake_quant_with_min_max_vars: input and output data types must be the same");
|
||||
|
||||
int numBits = 8;
|
||||
if (block.getIArguments() && block.getIArguments()->size())
|
||||
numBits = INT_ARG(0);
|
||||
|
|
|
@ -40,8 +40,10 @@ namespace nd4j {
|
|||
|
||||
REQUIRE_TRUE(depth == max->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Max length should be"
|
||||
"%lld, but %lld occurs.", depth, max->lengthOf());
|
||||
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
REQUIRE_TRUE(x->dataType() == output->dataType(), 0, "fake_quant_with_min_max_vars_per_channel: input and output data types must be the same");
|
||||
|
||||
int numBits = 8;
|
||||
if (block.getIArguments() && block.getIArguments()->size())
|
||||
numBits = INT_ARG(0);
|
||||
|
|
|
@ -25,6 +25,8 @@
|
|||
#include <exceptions/graph_exception.h>
|
||||
#include <exceptions/unresolved_input_exception.h>
|
||||
#include <ops/declarable/OpRegistrator.h>
|
||||
#include <exceptions/datatype_exception.h>
|
||||
#include <helpers/StringUtils.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -227,6 +229,15 @@ namespace nd4j {
|
|||
nd4j_printf("Expected vs provided shapes mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), pair.second);
|
||||
throw std::runtime_error("Expected vs provided shapes mismatch");
|
||||
}
|
||||
|
||||
/*
|
||||
* FIXME: we want to uncomment this eventually, and check data types equality
|
||||
//checking out data type equality
|
||||
if (ArrayOptions::dataType(out) != ArrayOptions::dataType(shape)) {
|
||||
std::string msg = "Provided array [" + StringUtils::valueToString<int>(pair.second) + "] has unexpected data type";
|
||||
throw nd4j::datatype_exception::build(msg, ArrayOptions::dataType(out), ArrayOptions::dataType(shape));
|
||||
}
|
||||
*/
|
||||
}
|
||||
} else {
|
||||
auto fout = ctx.fastpath_out();
|
||||
|
@ -237,6 +248,7 @@ namespace nd4j {
|
|||
ctx.setOutputArray(idx, outArr, true);
|
||||
} else {
|
||||
auto array = fout[idx];
|
||||
// checking out shape equality
|
||||
if (!shape::equalsSoft(out, array->shapeInfo()) || shape::isEmpty(out) != array->isEmpty()) {
|
||||
auto eShape = ShapeUtils::shapeAsString(out);
|
||||
auto aShape = ShapeUtils::shapeAsString(array->shapeInfo());
|
||||
|
@ -247,6 +259,15 @@ namespace nd4j {
|
|||
nd4j_printf("Expected vs provided shape mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), idx);
|
||||
throw std::runtime_error("Expected vs provided shape mismatch");
|
||||
}
|
||||
|
||||
/*
|
||||
* FIXME: we want to uncomment this eventually, and check data types equality
|
||||
//checking out data type equality
|
||||
if (ArrayOptions::dataType(out) != array->dataType()) {
|
||||
std::string msg = "Provided array [" + StringUtils::valueToString<int>(idx) + "] has unexpected data type";
|
||||
throw nd4j::datatype_exception::build(msg, ArrayOptions::dataType(out), array->dataType());
|
||||
}
|
||||
*/
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -426,6 +426,26 @@ TEST_F(JavaInteropTests, Test_FastPath_Validation_2) {
|
|||
ASSERT_NE(Status::OK(), status);
|
||||
}
|
||||
|
||||
TEST_F(JavaInteropTests, Test_FastPath_Validation_3) {
|
||||
auto x = NDArrayFactory::create<float>('c', {3, 5}, { 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f,
|
||||
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
|
||||
0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f});
|
||||
|
||||
auto min = NDArrayFactory::create<float>({ -0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f});
|
||||
auto max = NDArrayFactory::create<float>({ 0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
|
||||
|
||||
auto z = NDArrayFactory::create<double>('c', {3, 5});
|
||||
|
||||
Context ctx(1);
|
||||
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
|
||||
ctx.setInputArray(1, min.buffer(), min.shapeInfo(), min.specialBuffer(), min.specialShapeInfo());
|
||||
ctx.setInputArray(2, max.buffer(), max.shapeInfo(), max.specialBuffer(), max.specialShapeInfo());
|
||||
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
|
||||
|
||||
nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
|
||||
ASSERT_ANY_THROW(op.execute(&ctx));
|
||||
}
|
||||
|
||||
TEST_F(JavaInteropTests, Test_empty_cast_1) {
|
||||
auto x = NDArrayFactory::create<bool>('c', {1, 0, 2});
|
||||
auto z = NDArrayFactory::create<Nd4jLong>('c', {1, 0, 2});
|
||||
|
|
Loading…
Reference in New Issue