fusedbatchnorm: fix type mismatching error
Signed-off-by: AbdelRauf <rauf@konduit.ai>master
parent
c86300373e
commit
426e28640a
|
@ -89,8 +89,8 @@ namespace sd {
|
|||
else {
|
||||
//REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width());
|
||||
std::vector<Nd4jLong> shape = {iD};
|
||||
mean = NDArrayFactory::create_(scale->ordering(), shape, sd::DataType::FLOAT32, block.launchContext());
|
||||
variance = NDArrayFactory::create_(scale->ordering(), shape, sd::DataType::FLOAT32, block.launchContext());
|
||||
mean = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
|
||||
variance = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
|
||||
}
|
||||
|
||||
|
||||
|
@ -104,7 +104,7 @@ namespace sd {
|
|||
|
||||
const int restSize = x->lengthOf() / iD;
|
||||
|
||||
auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, sd::DataType::FLOAT32, block.launchContext());
|
||||
auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, mean->dataType(), block.launchContext());
|
||||
xAffected.assign(xCast);
|
||||
|
||||
const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1;
|
||||
|
|
Loading…
Reference in New Issue