reshape: fix optional order case that failed tests
Signed-off-by: AbdelRauf <rauf@konduit.ai>master
parent
375efff2e4
commit
1550cebcd5
|
@ -61,6 +61,29 @@ DECLARE_TYPES(reshape) {
|
|||
->setSameMode(true);
|
||||
}
|
||||
|
||||
|
||||
bool handleOptionalOrder(std::vector<int> &reshapeArgs, char &ordering){
|
||||
if(reshapeArgs.size()>0){
|
||||
//check if any optional negative ordering value is passed
|
||||
auto optional = reshapeArgs[0];
|
||||
if(optional < 0){
|
||||
optional = abs(optional);
|
||||
//check if passed option is allowed. (-1 -> dynamic shape)
|
||||
// in that case we will return back
|
||||
if(optional == 1 ) return true;
|
||||
//in this case it should obey allowed orderings
|
||||
if (optional != 'c' && optional != 'f' ) return false;
|
||||
reshapeArgs.erase( reshapeArgs.begin());
|
||||
//ordering was passed and ok. let's assign
|
||||
ordering = optional;
|
||||
}
|
||||
|
||||
}
|
||||
//skipped
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
DECLARE_SHAPE_FN(reshape) {
|
||||
|
||||
const auto x = INPUT_VARIABLE(0);
|
||||
|
@ -78,26 +101,14 @@ DECLARE_SHAPE_FN(reshape) {
|
|||
*/
|
||||
if (block.width() == 1) {
|
||||
reshapeArgs = *block.getIArguments();
|
||||
if (!reshapeArgs.empty()) {
|
||||
char potentialOrdering = (char)-reshapeArgs[0];
|
||||
orderNew = potentialOrdering;
|
||||
if (potentialOrdering != 'c' && potentialOrdering != 'f') {
|
||||
if(!handleOptionalOrder(reshapeArgs, orderNew)){
|
||||
throw std::runtime_error(
|
||||
"reshape:: Value passed in must be -99 or -102 for the ordering if "
|
||||
"an int array is present. -99 represents c ordering and -102 "
|
||||
"represents f ordering. This number is negative for the long array "
|
||||
"case to flag the difference between an ordering and a dimension "
|
||||
"being specified.");
|
||||
}
|
||||
|
||||
nd4j_debug("Reshape Ordering is %c int ordering is %d\n", orderNew,
|
||||
-reshapeArgs[0]);
|
||||
|
||||
if (orderNew == 'c' || orderNew == 'f')
|
||||
reshapeArgs.erase(
|
||||
reshapeArgs
|
||||
.begin()); // remove first element being order in this case
|
||||
}
|
||||
};
|
||||
} else {
|
||||
reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector<int>();
|
||||
if (block.numI() > 0) {
|
||||
|
|
Loading…
Reference in New Issue