Small fixes (#140)

* Allow scalar op result array auto allocation

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Don't swallow underlying exception for calculateOutputShape execution failures

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Ignore for known keras failure

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-12-21 17:00:46 +11:00 committed by GitHub
parent 495256c827
commit ce02b6fae7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 38 additions and 5 deletions

View File

@ -27,6 +27,7 @@ import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -255,10 +256,8 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest {
}
}
@Test
@Test @Ignore("AB 2019/11/23 - known issue - see https://github.com/eclipse/deeplearning4j/issues/8373 and https://github.com/eclipse/deeplearning4j/issues/8441")
public void ReshapeEmbeddingConcatTest() throws Exception{
//TODO AB 2019/11/23 - known issue - see https://github.com/eclipse/deeplearning4j/issues/8373 and https://github.com/eclipse/deeplearning4j/issues/8441
try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) {
ComputationGraphConfiguration config =
new KerasModel().modelBuilder().modelJsonInputStream(is)

View File

@ -1293,6 +1293,19 @@ public class CudaExecutioner extends DefaultOpExecutioner {
// validateDataType(Nd4j.dataType(), op);
if(op.z() == null){
switch (op.getOpType()) {
case SCALAR:
op.setZ(op.x().ulike());
break;
case SCALAR_BOOL:
op.setZ(Nd4j.createUninitialized(DataType.BOOL, op.x().shape()));
break;
default:
throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() +"]");
}
}
if (op.x().length() != op.z().length())
throw new ND4JIllegalStateException("op.X length should be equal to op.Y length: ["
+ Arrays.toString(op.x().shapeInfoDataBuffer().asInt()) + "] != ["
@ -2280,7 +2293,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
op.addOutputArgument(Nd4j.create(shape));
} catch (Exception e) {
throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
throw new ND4JIllegalStateException("Op name " + op.opName() + " - no output arrays were provided and calculateOutputShape failed to execute", e);
}
}

View File

@ -670,6 +670,19 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
//validateDataType(Nd4j.dataType(), op);
if(op.z() == null){
switch (op.getOpType()) {
case SCALAR:
op.setZ(op.x().ulike());
break;
case SCALAR_BOOL:
op.setZ(Nd4j.createUninitialized(DataType.BOOL, op.x().shape()));
break;
default:
throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() +"]");
}
}
if (op.x().length() != op.z().length())
throw new ND4JIllegalStateException("op.X length should be equal to op.Z length: " +
"x.length()=" + op.x().length() + ", z.length()=" + op.z().length() + " - x shape info = ["
@ -1689,7 +1702,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
} catch (ND4JIllegalStateException e){
throw e;
} catch (Exception e) {
throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
throw new ND4JIllegalStateException("Op name " + op.opName() + " - no output arrays were provided and calculateOutputShape failed to execute", e);
}
}

View File

@ -66,6 +66,7 @@ import org.nd4j.linalg.api.ops.impl.reduce.same.Sum;
import org.nd4j.linalg.api.ops.impl.reduce3.*;
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
@ -8164,6 +8165,13 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(e, z);
}
@Test
public void testScalarEqualsNoResult(){
INDArray out = Nd4j.exec(new ScalarEquals(Nd4j.createFromArray(-2, -1, 0, 1, 2), null, 0));
INDArray exp = Nd4j.createFromArray(false, false, true, false, false);
assertEquals(exp, out);
}
@Override
public char ordering() {
return 'c';