fusedbatchnorm: fix type mismatching error
Signed-off-by: AbdelRauf <rauf@konduit.ai>master
parent
c86300373e
commit
426e28640a
|
@ -89,8 +89,8 @@ namespace sd {
|
||||||
else {
|
else {
|
||||||
//REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width());
|
//REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width());
|
||||||
std::vector<Nd4jLong> shape = {iD};
|
std::vector<Nd4jLong> shape = {iD};
|
||||||
mean = NDArrayFactory::create_(scale->ordering(), shape, sd::DataType::FLOAT32, block.launchContext());
|
mean = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
|
||||||
variance = NDArrayFactory::create_(scale->ordering(), shape, sd::DataType::FLOAT32, block.launchContext());
|
variance = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -104,7 +104,7 @@ namespace sd {
|
||||||
|
|
||||||
const int restSize = x->lengthOf() / iD;
|
const int restSize = x->lengthOf() / iD;
|
||||||
|
|
||||||
auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, sd::DataType::FLOAT32, block.launchContext());
|
auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, mean->dataType(), block.launchContext());
|
||||||
xAffected.assign(xCast);
|
xAffected.assign(xCast);
|
||||||
|
|
||||||
const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1;
|
const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1;
|
||||||
|
|
Loading…
Reference in New Issue