diff --git a/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp b/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp index ccdf60f40..d9e48d1c1 100644 --- a/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp @@ -89,8 +89,8 @@ namespace sd { else { //REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width()); std::vector shape = {iD}; - mean = NDArrayFactory::create_(scale->ordering(), shape, sd::DataType::FLOAT32, block.launchContext()); - variance = NDArrayFactory::create_(scale->ordering(), shape, sd::DataType::FLOAT32, block.launchContext()); + mean = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext()); + variance = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext()); } @@ -104,7 +104,7 @@ namespace sd { const int restSize = x->lengthOf() / iD; - auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, sd::DataType::FLOAT32, block.launchContext()); + auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, mean->dataType(), block.launchContext()); xAffected.assign(xCast); const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1;