diff --git a/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp b/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp index ccdf60f40..d9e48d1c1 100644 --- a/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp @@ -89,8 +89,8 @@ namespace sd { else { //REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width()); std::vector shape = {iD}; - mean = NDArrayFactory::create_(scale->ordering(), shape, sd::DataType::FLOAT32, block.launchContext()); - variance = NDArrayFactory::create_(scale->ordering(), shape, sd::DataType::FLOAT32, block.launchContext()); + mean = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext()); + variance = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext()); } @@ -104,7 +104,7 @@ namespace sd { const int restSize = x->lengthOf() / iD; - auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, sd::DataType::FLOAT32, block.launchContext()); + auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, mean->dataType(), block.launchContext()); xAffected.assign(xCast); const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1; diff --git a/libnd4j/include/ops/declarable/generic/random/uniform.cpp b/libnd4j/include/ops/declarable/generic/random/uniform.cpp index cb7f146da..7743255d1 100644 --- a/libnd4j/include/ops/declarable/generic/random/uniform.cpp +++ b/libnd4j/include/ops/declarable/generic/random/uniform.cpp @@ -40,7 +40,7 @@ namespace sd { * TArgs[0] - min for rng * TArgs[1] - max for rng */ - CUSTOM_OP_IMPL(randomuniform, -1, 1, true, 0, -1) { + CUSTOM_OP_IMPL(randomuniform, -1, 1, true, 0, -2) { // uniform distribution auto rng = block.randomGenerator(); auto dtype = DataType::FLOAT32; diff --git a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp index bca23c1cc..2932bc455 100644 --- a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp @@ -61,6 +61,29 @@ DECLARE_TYPES(reshape) { ->setSameMode(true); } + +bool handleOptionalOrder(std::vector &reshapeArgs, char &ordering){ + if(reshapeArgs.size()>0){ + //check if any optional negative ordering value is passed + auto optional = reshapeArgs[0]; + if(optional < 0){ + optional = abs(optional); + //check if passed option is allowed. (-1 -> dynamic shape) + // in that case we will return back + if(optional == 1 ) return true; + //in this case it should obey allowed orderings + if (optional != 'c' && optional != 'f' ) return false; + reshapeArgs.erase( reshapeArgs.begin()); + //ordering was passed and ok. let's assign + ordering = optional; + } + + } + //skipped + return true; +} + + DECLARE_SHAPE_FN(reshape) { const auto x = INPUT_VARIABLE(0); @@ -78,26 +101,14 @@ DECLARE_SHAPE_FN(reshape) { */ if (block.width() == 1) { reshapeArgs = *block.getIArguments(); - if (!reshapeArgs.empty()) { - char potentialOrdering = (char)-reshapeArgs[0]; - orderNew = potentialOrdering; - if (potentialOrdering != 'c' && potentialOrdering != 'f') { + if(!handleOptionalOrder(reshapeArgs, orderNew)){ throw std::runtime_error( "reshape:: Value passed in must be -99 or -102 for the ordering if " "an int array is present. -99 represents c ordering and -102 " "represents f ordering. This number is negative for the long array " "case to flag the difference between an ordering and a dimension " "being specified."); - } - - nd4j_debug("Reshape Ordering is %c int ordering is %d\n", orderNew, - -reshapeArgs[0]); - - if (orderNew == 'c' || orderNew == 'f') - reshapeArgs.erase( - reshapeArgs - .begin()); // remove first element being order in this case - } + }; } else { reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector(); if (block.numI() > 0) { diff --git a/libnd4j/tests_cpu/layers_tests/SparseUtilsTest.cpp b/libnd4j/tests_cpu/layers_tests/SparseUtilsTest.cpp index f83e61eb3..8afef3701 100644 --- a/libnd4j/tests_cpu/layers_tests/SparseUtilsTest.cpp +++ b/libnd4j/tests_cpu/layers_tests/SparseUtilsTest.cpp @@ -227,6 +227,7 @@ TEST_F(SparseUtilsTest, RavelIndices_Test) { } shape[2] = 30; + delete[] shapeInfoBuffer; shapeInfoBuffer = shape::shapeBuffer(rank, sd::DataType::INT64, shape); try {