commit
405307dea3
|
@ -89,8 +89,8 @@ namespace sd {
|
||||||
else {
|
else {
|
||||||
//REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width());
|
//REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width());
|
||||||
std::vector<Nd4jLong> shape = {iD};
|
std::vector<Nd4jLong> shape = {iD};
|
||||||
mean = NDArrayFactory::create_(scale->ordering(), shape, sd::DataType::FLOAT32, block.launchContext());
|
mean = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
|
||||||
variance = NDArrayFactory::create_(scale->ordering(), shape, sd::DataType::FLOAT32, block.launchContext());
|
variance = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -104,7 +104,7 @@ namespace sd {
|
||||||
|
|
||||||
const int restSize = x->lengthOf() / iD;
|
const int restSize = x->lengthOf() / iD;
|
||||||
|
|
||||||
auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, sd::DataType::FLOAT32, block.launchContext());
|
auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, mean->dataType(), block.launchContext());
|
||||||
xAffected.assign(xCast);
|
xAffected.assign(xCast);
|
||||||
|
|
||||||
const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1;
|
const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1;
|
||||||
|
|
|
@ -40,7 +40,7 @@ namespace sd {
|
||||||
* TArgs[0] - min for rng
|
* TArgs[0] - min for rng
|
||||||
* TArgs[1] - max for rng
|
* TArgs[1] - max for rng
|
||||||
*/
|
*/
|
||||||
CUSTOM_OP_IMPL(randomuniform, -1, 1, true, 0, -1) {
|
CUSTOM_OP_IMPL(randomuniform, -1, 1, true, 0, -2) {
|
||||||
// uniform distribution
|
// uniform distribution
|
||||||
auto rng = block.randomGenerator();
|
auto rng = block.randomGenerator();
|
||||||
auto dtype = DataType::FLOAT32;
|
auto dtype = DataType::FLOAT32;
|
||||||
|
|
|
@ -61,6 +61,29 @@ DECLARE_TYPES(reshape) {
|
||||||
->setSameMode(true);
|
->setSameMode(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
bool handleOptionalOrder(std::vector<int> &reshapeArgs, char &ordering){
|
||||||
|
if(reshapeArgs.size()>0){
|
||||||
|
//check if any optional negative ordering value is passed
|
||||||
|
auto optional = reshapeArgs[0];
|
||||||
|
if(optional < 0){
|
||||||
|
optional = abs(optional);
|
||||||
|
//check if passed option is allowed. (-1 -> dynamic shape)
|
||||||
|
// in that case we will return back
|
||||||
|
if(optional == 1 ) return true;
|
||||||
|
//in this case it should obey allowed orderings
|
||||||
|
if (optional != 'c' && optional != 'f' ) return false;
|
||||||
|
reshapeArgs.erase( reshapeArgs.begin());
|
||||||
|
//ordering was passed and ok. let's assign
|
||||||
|
ordering = optional;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
//skipped
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(reshape) {
|
DECLARE_SHAPE_FN(reshape) {
|
||||||
|
|
||||||
const auto x = INPUT_VARIABLE(0);
|
const auto x = INPUT_VARIABLE(0);
|
||||||
|
@ -78,26 +101,14 @@ DECLARE_SHAPE_FN(reshape) {
|
||||||
*/
|
*/
|
||||||
if (block.width() == 1) {
|
if (block.width() == 1) {
|
||||||
reshapeArgs = *block.getIArguments();
|
reshapeArgs = *block.getIArguments();
|
||||||
if (!reshapeArgs.empty()) {
|
if(!handleOptionalOrder(reshapeArgs, orderNew)){
|
||||||
char potentialOrdering = (char)-reshapeArgs[0];
|
|
||||||
orderNew = potentialOrdering;
|
|
||||||
if (potentialOrdering != 'c' && potentialOrdering != 'f') {
|
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"reshape:: Value passed in must be -99 or -102 for the ordering if "
|
"reshape:: Value passed in must be -99 or -102 for the ordering if "
|
||||||
"an int array is present. -99 represents c ordering and -102 "
|
"an int array is present. -99 represents c ordering and -102 "
|
||||||
"represents f ordering. This number is negative for the long array "
|
"represents f ordering. This number is negative for the long array "
|
||||||
"case to flag the difference between an ordering and a dimension "
|
"case to flag the difference between an ordering and a dimension "
|
||||||
"being specified.");
|
"being specified.");
|
||||||
}
|
};
|
||||||
|
|
||||||
nd4j_debug("Reshape Ordering is %c int ordering is %d\n", orderNew,
|
|
||||||
-reshapeArgs[0]);
|
|
||||||
|
|
||||||
if (orderNew == 'c' || orderNew == 'f')
|
|
||||||
reshapeArgs.erase(
|
|
||||||
reshapeArgs
|
|
||||||
.begin()); // remove first element being order in this case
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector<int>();
|
reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector<int>();
|
||||||
if (block.numI() > 0) {
|
if (block.numI() > 0) {
|
||||||
|
|
|
@ -227,6 +227,7 @@ TEST_F(SparseUtilsTest, RavelIndices_Test) {
|
||||||
}
|
}
|
||||||
|
|
||||||
shape[2] = 30;
|
shape[2] = 30;
|
||||||
|
delete[] shapeInfoBuffer;
|
||||||
shapeInfoBuffer = shape::shapeBuffer(rank, sd::DataType::INT64, shape);
|
shapeInfoBuffer = shape::shapeBuffer(rank, sd::DataType::INT64, shape);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|
Loading…
Reference in New Issue