reshape: fix optional order case that failed tests

Signed-off-by: AbdelRauf <rauf@konduit.ai>
master
AbdelRauf 2021-02-23 21:11:30 +01:00
parent 375efff2e4
commit 1550cebcd5
1 changed files with 25 additions and 14 deletions

View File

@ -61,6 +61,29 @@ DECLARE_TYPES(reshape) {
->setSameMode(true); ->setSameMode(true);
} }
bool handleOptionalOrder(std::vector<int> &reshapeArgs, char &ordering){
if(reshapeArgs.size()>0){
//check if any optional negative ordering value is passed
auto optional = reshapeArgs[0];
if(optional < 0){
optional = abs(optional);
//check if passed option is allowed. (-1 -> dynamic shape)
// in that case we will return back
if(optional == 1 ) return true;
//in this case it should obey allowed orderings
if (optional != 'c' && optional != 'f' ) return false;
reshapeArgs.erase( reshapeArgs.begin());
//ordering was passed and ok. let's assign
ordering = optional;
}
}
//skipped
return true;
}
DECLARE_SHAPE_FN(reshape) { DECLARE_SHAPE_FN(reshape) {
const auto x = INPUT_VARIABLE(0); const auto x = INPUT_VARIABLE(0);
@ -78,26 +101,14 @@ DECLARE_SHAPE_FN(reshape) {
*/ */
if (block.width() == 1) { if (block.width() == 1) {
reshapeArgs = *block.getIArguments(); reshapeArgs = *block.getIArguments();
if (!reshapeArgs.empty()) { if(!handleOptionalOrder(reshapeArgs, orderNew)){
char potentialOrdering = (char)-reshapeArgs[0];
orderNew = potentialOrdering;
if (potentialOrdering != 'c' && potentialOrdering != 'f') {
throw std::runtime_error( throw std::runtime_error(
"reshape:: Value passed in must be -99 or -102 for the ordering if " "reshape:: Value passed in must be -99 or -102 for the ordering if "
"an int array is present. -99 represents c ordering and -102 " "an int array is present. -99 represents c ordering and -102 "
"represents f ordering. This number is negative for the long array " "represents f ordering. This number is negative for the long array "
"case to flag the difference between an ordering and a dimension " "case to flag the difference between an ordering and a dimension "
"being specified."); "being specified.");
} };
nd4j_debug("Reshape Ordering is %c int ordering is %d\n", orderNew,
-reshapeArgs[0]);
if (orderNew == 'c' || orderNew == 'f')
reshapeArgs.erase(
reshapeArgs
.begin()); // remove first element being order in this case
}
} else { } else {
reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector<int>(); reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector<int>();
if (block.numI() > 0) { if (block.numI() > 0) {