fusedbatchnorm: fix type mismatching error

Signed-off-by: AbdelRauf <rauf@konduit.ai>
master
AbdelRauf 2021-02-23 21:16:07 +01:00
parent c86300373e
commit 426e28640a
1 changed files with 3 additions and 3 deletions

View File

@ -89,8 +89,8 @@ namespace sd {
else { else {
//REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width()); //REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width());
std::vector<Nd4jLong> shape = {iD}; std::vector<Nd4jLong> shape = {iD};
mean = NDArrayFactory::create_(scale->ordering(), shape, sd::DataType::FLOAT32, block.launchContext()); mean = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
variance = NDArrayFactory::create_(scale->ordering(), shape, sd::DataType::FLOAT32, block.launchContext()); variance = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
} }
@ -104,7 +104,7 @@ namespace sd {
const int restSize = x->lengthOf() / iD; const int restSize = x->lengthOf() / iD;
auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, sd::DataType::FLOAT32, block.launchContext()); auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, mean->dataType(), block.launchContext());
xAffected.assign(xCast); xAffected.assign(xCast);
const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1; const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1;