Small fixes (#140)
* Allow scalar op result array auto allocation Signed-off-by: AlexDBlack <blacka101@gmail.com> * Don't swallow underlying exception for calculateOutputShape execution failures Signed-off-by: AlexDBlack <blacka101@gmail.com> * Ignore for known keras failure Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
495256c827
commit
ce02b6fae7
|
@ -27,6 +27,7 @@ import org.deeplearning4j.nn.modelimport.keras.KerasModel;
|
|||
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -255,10 +256,8 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test @Ignore("AB 2019/11/23 - known issue - see https://github.com/eclipse/deeplearning4j/issues/8373 and https://github.com/eclipse/deeplearning4j/issues/8441")
|
||||
public void ReshapeEmbeddingConcatTest() throws Exception{
|
||||
//TODO AB 2019/11/23 - known issue - see https://github.com/eclipse/deeplearning4j/issues/8373 and https://github.com/eclipse/deeplearning4j/issues/8441
|
||||
|
||||
try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) {
|
||||
ComputationGraphConfiguration config =
|
||||
new KerasModel().modelBuilder().modelJsonInputStream(is)
|
||||
|
|
|
@ -1293,6 +1293,19 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
// validateDataType(Nd4j.dataType(), op);
|
||||
|
||||
if(op.z() == null){
|
||||
switch (op.getOpType()) {
|
||||
case SCALAR:
|
||||
op.setZ(op.x().ulike());
|
||||
break;
|
||||
case SCALAR_BOOL:
|
||||
op.setZ(Nd4j.createUninitialized(DataType.BOOL, op.x().shape()));
|
||||
break;
|
||||
default:
|
||||
throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() +"]");
|
||||
}
|
||||
}
|
||||
|
||||
if (op.x().length() != op.z().length())
|
||||
throw new ND4JIllegalStateException("op.X length should be equal to op.Y length: ["
|
||||
+ Arrays.toString(op.x().shapeInfoDataBuffer().asInt()) + "] != ["
|
||||
|
@ -2280,7 +2293,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
op.addOutputArgument(Nd4j.create(shape));
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
|
||||
throw new ND4JIllegalStateException("Op name " + op.opName() + " - no output arrays were provided and calculateOutputShape failed to execute", e);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -670,6 +670,19 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
//validateDataType(Nd4j.dataType(), op);
|
||||
|
||||
if(op.z() == null){
|
||||
switch (op.getOpType()) {
|
||||
case SCALAR:
|
||||
op.setZ(op.x().ulike());
|
||||
break;
|
||||
case SCALAR_BOOL:
|
||||
op.setZ(Nd4j.createUninitialized(DataType.BOOL, op.x().shape()));
|
||||
break;
|
||||
default:
|
||||
throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() +"]");
|
||||
}
|
||||
}
|
||||
|
||||
if (op.x().length() != op.z().length())
|
||||
throw new ND4JIllegalStateException("op.X length should be equal to op.Z length: " +
|
||||
"x.length()=" + op.x().length() + ", z.length()=" + op.z().length() + " - x shape info = ["
|
||||
|
@ -1689,7 +1702,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
} catch (ND4JIllegalStateException e){
|
||||
throw e;
|
||||
} catch (Exception e) {
|
||||
throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
|
||||
throw new ND4JIllegalStateException("Op name " + op.opName() + " - no output arrays were provided and calculateOutputShape failed to execute", e);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -66,6 +66,7 @@ import org.nd4j.linalg.api.ops.impl.reduce.same.Sum;
|
|||
import org.nd4j.linalg.api.ops.impl.reduce3.*;
|
||||
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
|
||||
import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans;
|
||||
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals;
|
||||
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
|
||||
|
@ -8164,6 +8165,13 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
assertEquals(e, z);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testScalarEqualsNoResult(){
|
||||
INDArray out = Nd4j.exec(new ScalarEquals(Nd4j.createFromArray(-2, -1, 0, 1, 2), null, 0));
|
||||
INDArray exp = Nd4j.createFromArray(false, false, true, false, false);
|
||||
assertEquals(exp, out);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
|
|
Loading…
Reference in New Issue