Various DL4J/ND4J fixes (#81)

* #7954 Force refresh of UI when switching tabs on overview page

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8017 Concurrent modification exception (synchronize) fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8033 Don't initialize updater in middle of writing memory crash dump

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8208 Fix shape checks for ND4J int[] creator methods

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #6385 #7992 Keras import naming fixes + cleanup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8016 Upsampling3D - add NDHWC format support

Signed-off-by: AlexDBlack <blacka101@gmail.com>
Branch: master
Author: Alex Black, 2019-07-26 00:05:24 +10:00 (committed by AlexDBlack)
Parent: 7c5c84bea8
Commit: fad8da878f
13 changed files with 204 additions and 117 deletions

View File

@@ -386,12 +386,13 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
         for (Activation afn : activations) {
             for (int miniBatchSize : minibatchSizes) {
                 for (ConvolutionMode mode : modes) {
+                    for(Convolution3D.DataFormat df : Convolution3D.DataFormat.values()) {
                     int outDepth = depth * upsamplingSize[0];
                     int outHeight = height * upsamplingSize[1];
                     int outWidth = width * upsamplingSize[2];
-                    INDArray input = Nd4j.rand(new int[]{miniBatchSize, convNIn, depth, height, width});
+                    INDArray input = df == Convolution3D.DataFormat.NCDHW ? Nd4j.rand(miniBatchSize, convNIn, depth, height, width) : Nd4j.rand(miniBatchSize, depth, height, width, convNIn);
                     INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut);
                     for (int i = 0; i < miniBatchSize; i++) {
                         labels.putScalar(new int[]{i, i % finalNOut}, 1.0);
@@ -405,16 +406,16 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
                     .list()
                     .layer(0, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1)
                             .nIn(convNIn).nOut(convNOut).hasBias(false)
-                            .convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW)
+                            .convolutionMode(mode).dataFormat(df)
                             .build())
-                    .layer(1, new Upsampling3D.Builder(upsamplingSize[0]).build())
+                    .layer(1, new Upsampling3D.Builder(upsamplingSize[0]).dataFormat(df).build())
                     .layer(2, new DenseLayer.Builder().nOut(denseNOut).build())
                     .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                             .activation(Activation.SOFTMAX).nOut(finalNOut).build())
                     .inputPreProcessor(2,
                             new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth,
                                     convNOut, true))
-                    .setInputType(InputType.convolutional3D(depth, height, width, convNIn)).build();
+                    .setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build();

             String json = conf.toJson();
             MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
@@ -442,7 +443,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
                     assertTrue(msg, gradOK);

                     TestUtils.testModelSerialization(net);
+                    }
                 }
             }
         }
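For reference, the two 3D data formats differ only in where the channel axis sits, which is why the test now builds its input per format. A small illustrative sketch (shape values are arbitrary, not from the test):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class DataFormatSketch {
        public static void main(String[] args) {
            // NCDHW ("channels first"): [minibatch, channels, depth, height, width]
            INDArray ncdhw = Nd4j.rand(2, 3, 4, 5, 6);
            // NDHWC ("channels last"): [minibatch, depth, height, width, channels]
            INDArray ndhwc = Nd4j.rand(2, 4, 5, 6, 3);
            System.out.println(java.util.Arrays.toString(ncdhw.shape())); // [2, 3, 4, 5, 6]
            System.out.println(java.util.Arrays.toString(ndhwc.shape())); // [2, 4, 5, 6, 3]
        }
    }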

View File

@@ -32,6 +32,17 @@ import java.util.Map;
  */
 @Slf4j
 public class KerasOptimizerUtils {

+    protected static final String LR = "lr";
+    protected static final String LR2 = "learning_rate";
+    protected static final String EPSILON = "epsilon";
+    protected static final String MOMENTUM = "momentum";
+    protected static final String BETA_1 = "beta_1";
+    protected static final String BETA_2 = "beta_2";
+    protected static final String DECAY = "decay";
+    protected static final String RHO = "rho";
+    protected static final String SCHEDULE_DECAY = "schedule_decay";
+
     /**
      * Map Keras optimizer to DL4J IUpdater.
      *
@@ -55,11 +66,11 @@ public class KerasOptimizerUtils {
         switch (optimizerName) {
             case "Adam": {
-                double lr = (double) optimizerParameters.get("lr");
-                double beta1 = (double) optimizerParameters.get("beta_1");
-                double beta2 = (double) optimizerParameters.get("beta_2");
-                double epsilon = (double) optimizerParameters.get("epsilon");
-                double decay = (double) optimizerParameters.get("decay");
+                double lr = (double) (optimizerParameters.containsKey(LR) ? optimizerParameters.get(LR) : optimizerParameters.get(LR2));
+                double beta1 = (double) optimizerParameters.get(BETA_1);
+                double beta2 = (double) optimizerParameters.get(BETA_2);
+                double epsilon = (double) optimizerParameters.get(EPSILON);
+                double decay = (double) optimizerParameters.get(DECAY);

                 dl4jOptimizer = new Adam.Builder()
                         .beta1(beta1).beta2(beta2)
@@ -69,9 +80,9 @@ public class KerasOptimizerUtils {
                 break;
             }
             case "Adadelta": {
-                double rho = (double) optimizerParameters.get("rho");
-                double epsilon = (double) optimizerParameters.get("epsilon");
-                // double decay = (double) optimizerParameters.get("decay"); No decay in DL4J Adadelta
+                double rho = (double) optimizerParameters.get(RHO);
+                double epsilon = (double) optimizerParameters.get(EPSILON);
+                // double decay = (double) optimizerParameters.get(DECAY); No decay in DL4J Adadelta

                 dl4jOptimizer = new AdaDelta.Builder()
                         .epsilon(epsilon).rho(rho)
@@ -79,9 +90,9 @@ public class KerasOptimizerUtils {
                 break;
             }
             case "Adgrad": {
-                double lr = (double) optimizerParameters.get("lr");
-                double epsilon = (double) optimizerParameters.get("epsilon");
-                double decay = (double) optimizerParameters.get("decay");
+                double lr = (double) (optimizerParameters.containsKey(LR) ? optimizerParameters.get(LR) : optimizerParameters.get(LR2));
+                double epsilon = (double) optimizerParameters.get(EPSILON);
+                double decay = (double) optimizerParameters.get(DECAY);

                 dl4jOptimizer = new AdaGrad.Builder()
                         .epsilon(epsilon).learningRate(lr)
@@ -90,20 +101,20 @@ public class KerasOptimizerUtils {
                 break;
             }
             case "Adamax": {
-                double lr = (double) optimizerParameters.get("lr");
-                double beta1 = (double) optimizerParameters.get("beta_1");
-                double beta2 = (double) optimizerParameters.get("beta_2");
-                double epsilon = (double) optimizerParameters.get("epsilon");
+                double lr = (double) (optimizerParameters.containsKey(LR) ? optimizerParameters.get(LR) : optimizerParameters.get(LR2));
+                double beta1 = (double) optimizerParameters.get(BETA_1);
+                double beta2 = (double) optimizerParameters.get(BETA_2);
+                double epsilon = (double) optimizerParameters.get(EPSILON);

                 dl4jOptimizer = new AdaMax(lr, beta1, beta2, epsilon);
                 break;
             }
             case "Nadam": {
-                double lr = (double) optimizerParameters.get("lr");
-                double beta1 = (double) optimizerParameters.get("beta_1");
-                double beta2 = (double) optimizerParameters.get("beta_2");
-                double epsilon = (double) optimizerParameters.get("epsilon");
-                double scheduleDecay = (double) optimizerParameters.get("schedule_decay");
+                double lr = (double) (optimizerParameters.containsKey(LR) ? optimizerParameters.get(LR) : optimizerParameters.get(LR2));
+                double beta1 = (double) optimizerParameters.get(BETA_1);
+                double beta2 = (double) optimizerParameters.get(BETA_2);
+                double epsilon = (double) optimizerParameters.get(EPSILON);
+                double scheduleDecay = (double) optimizerParameters.get(SCHEDULE_DECAY);

                 dl4jOptimizer = new Nadam.Builder()
                         .beta1(beta1).beta2(beta2)
@@ -114,15 +125,10 @@ public class KerasOptimizerUtils {
                 break;
             }
             case "SGD": {
-                double lr = (double) optimizerParameters.get("lr");
-                double momentum = 0.0;
-                try {
-                    momentum = (double) optimizerParameters.get("epsilon");
-                } catch (Exception e) {
-                    log.warn("couldn't read momentum parameter");
-                }
-                double decay = (double) optimizerParameters.get("decay");
+                double lr = (double) (optimizerParameters.containsKey(LR) ? optimizerParameters.get(LR) : optimizerParameters.get(LR2));
+                double momentum = (double) (optimizerParameters.containsKey(EPSILON) ? optimizerParameters.get(EPSILON) : optimizerParameters.get(MOMENTUM));
+                double decay = (double) optimizerParameters.get(DECAY);

                 dl4jOptimizer = new Nesterovs.Builder()
                         .momentum(momentum).learningRate(lr)
@@ -131,10 +137,10 @@ public class KerasOptimizerUtils {
                 break;
             }
             case "RMSprop": {
-                double lr = (double) optimizerParameters.get("lr");
-                double rho = (double) optimizerParameters.get("rho");
-                double epsilon = (double) optimizerParameters.get("epsilon");
-                double decay = (double) optimizerParameters.get("decay");
+                double lr = (double) (optimizerParameters.containsKey(LR) ? optimizerParameters.get(LR) : optimizerParameters.get(LR2));
+                double rho = (double) optimizerParameters.get(RHO);
+                double epsilon = (double) optimizerParameters.get(EPSILON);
+                double decay = (double) optimizerParameters.get(DECAY);

                 dl4jOptimizer = new RmsProp.Builder()
                         .epsilon(epsilon).rmsDecay(rho).learningRate(lr)
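Note the lr/learning_rate fallback added throughout: newer Keras releases renamed the optimizer learning-rate field from "lr" to "learning_rate", so the import now checks both keys. A standalone sketch of the lookup pattern (the params map here is illustrative, not from the commit):

    import java.util.HashMap;
    import java.util.Map;

    public class LrLookupSketch {
        public static void main(String[] args) {
            Map<String, Object> optimizerParameters = new HashMap<>();
            optimizerParameters.put("learning_rate", 0.001); // newer Keras key

            // Prefer the legacy "lr" key; fall back to "learning_rate" when absent
            double lr = (double) (optimizerParameters.containsKey("lr")
                    ? optimizerParameters.get("lr")
                    : optimizerParameters.get("learning_rate"));
            System.out.println(lr); // prints 0.001
        }
    }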

View File

@@ -45,10 +45,14 @@ import java.util.Map;
 public class Upsampling3D extends BaseUpsamplingLayer {

     protected int[] size;
+    protected Convolution3D.DataFormat dataFormat = Convolution3D.DataFormat.NCDHW; //Default to NCDHW for 1.0.0-beta4 and earlier, when no config existed (NCDHW only)

-    protected Upsampling3D(UpsamplingBuilder builder) {
+    protected Upsampling3D(Builder builder) {
         super(builder);
         this.size = builder.size;
+        this.dataFormat = builder.dataFormat;
     }

     @Override
@@ -124,10 +128,32 @@ public class Upsampling3D extends BaseUpsamplingLayer {
     @NoArgsConstructor
     public static class Builder extends UpsamplingBuilder<Builder> {

+        protected Convolution3D.DataFormat dataFormat = Convolution3D.DataFormat.NCDHW;
+
+        /**
+         * @param size Upsampling layer size (most common value: 2)
+         */
         public Builder(int size) {
             super(new int[] {size, size, size});
         }

+        /**
+         * @param dataFormat Data format - see {@link Convolution3D.DataFormat} for more details
+         * @param size Upsampling layer size (most common value: 2)
+         */
+        public Builder(@NonNull Convolution3D.DataFormat dataFormat, int size){
+            super(new int[]{size, size, size});
+            this.dataFormat = dataFormat;
+        }
+
+        /**
+         * Sets the DataFormat. See {@link Convolution3D.DataFormat} for more details
+         */
+        public Builder dataFormat(@NonNull Convolution3D.DataFormat dataFormat){
+            this.dataFormat = dataFormat;
+            return this;
+        }
+
         /**
          * Upsampling size as int, so same upsampling size is used for depth, width and height
          *
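With these additions the data format can be supplied either through the new constructor or the fluent setter; omitting it preserves the old channels-first behaviour. A hedged usage sketch (not from the commit):

    // Two equivalent ways to request channels-last (NDHWC) upsampling:
    Upsampling3D up1 = new Upsampling3D.Builder(Convolution3D.DataFormat.NDHWC, 2).build();
    Upsampling3D up2 = new Upsampling3D.Builder(2)
            .dataFormat(Convolution3D.DataFormat.NDHWC)
            .build();
    // No dataFormat call: defaults to NCDHW, matching 1.0.0-beta4 and earlier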

View File

@@ -2896,7 +2896,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
             solver.getOptimizer().setUpdaterComputationGraph(new ComputationGraphUpdater(this));
         }
         if(solver != null) {
-            return solver.getOptimizer().getComputationGraphUpdater();
+            return solver.getOptimizer().getComputationGraphUpdater(initializeIfAbsent);
         }
         return null;
     }

View File

@@ -67,18 +67,36 @@ public class Upsampling3D extends AbstractLayer<org.deeplearning4j.nn.conf.layers.Upsampling3D> {
     public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
         assertInputSet(true);

+        boolean ncdhw = layerConf().getDataFormat() == org.deeplearning4j.nn.conf.layers.Convolution3D.DataFormat.NCDHW;

         // FIXME: int cast
         // Assumes NCDHW order
         int miniBatch = (int) input.size(0);
-        int inChannels = (int) input.size(1);
-        int inD = (int) input.size(2);
-        int inH = (int) input.size(3);
-        int inW = (int) input.size(4);
+        int inChannels, inD, inH, inW;
+        int[] intArgs;
+        if(ncdhw){
+            inChannels = (int) input.size(1);
+            inD = (int) input.size(2);
+            inH = (int) input.size(3);
+            inW = (int) input.size(4);
+            intArgs = new int[] {1}; // 1 is channels first
+        } else {
+            inD = (int) input.size(1);
+            inH = (int) input.size(2);
+            inW = (int) input.size(3);
+            inChannels = (int) input.size(4);
+            intArgs = new int[] {0}; // 0 is channels last
+        }

-        int[] intArgs = new int[] {1}; // 1 is channels first
-        INDArray reshapedEpsilon = workspaceMgr.createUninitialized(
-                ArrayType.ACTIVATION_GRAD, epsilon.dataType(), new long[]{miniBatch, inChannels, inD, inH, inW}, 'c');
+        INDArray epsOut;
+        if(ncdhw){
+            epsOut = workspaceMgr.createUninitialized(
+                ArrayType.ACTIVATION_GRAD, epsilon.dataType(), new long[]{miniBatch, inChannels, inD, inH, inW}, 'c');
+        } else {
+            epsOut = workspaceMgr.createUninitialized(
+                ArrayType.ACTIVATION_GRAD, epsilon.dataType(), new long[]{miniBatch, inD, inH, inW, inChannels}, 'c');
+        }

         Gradient gradient = new DefaultGradient();
@@ -86,13 +104,13 @@ public class Upsampling3D extends AbstractLayer<org.deeplearning4j.nn.conf.layers.Upsampling3D> {
         CustomOp op = DynamicCustomOp.builder("upsampling3d_bp")
                 .addIntegerArguments(intArgs)
                 .addInputs(input, epsilon)
-                .addOutputs(reshapedEpsilon)
+                .addOutputs(epsOut)
                 .callInplace(false)
                 .build();
         Nd4j.getExecutioner().exec(op);

-        reshapedEpsilon = backpropDropOutIfPresent(reshapedEpsilon);
-        return new Pair<>(gradient, reshapedEpsilon);
+        epsOut = backpropDropOutIfPresent(epsOut);
+        return new Pair<>(gradient, epsOut);
     }

     protected int[] getSize() {
@@ -115,32 +133,51 @@ public class Upsampling3D extends AbstractLayer<org.deeplearning4j.nn.conf.layers.Upsampling3D> {
             return preOutput;
         }

-        long miniBatch = (int) input.size(0);
-        long inChannels = (int) input.size(1);
-        long inD = (int) input.size(2);
-        long inH = (int) input.size(3);
-        long inW = (int) input.size(4);
+        boolean ncdhw = layerConf().getDataFormat() == org.deeplearning4j.nn.conf.layers.Convolution3D.DataFormat.NCDHW;
+        // FIXME: int cast
+        int miniBatch = (int) input.size(0);
+        int inChannels, inD, inH, inW;
+        int[] intArgs;

         int[] size = getSize();
+        if(ncdhw){
+            inChannels = (int) input.size(1);
+            inD = (int) input.size(2);
+            inH = (int) input.size(3);
+            inW = (int) input.size(4);
+            intArgs = new int[] {size[0], size[1], size[2], 1}; // 1 is channels first
+        } else {
+            inD = (int) input.size(1);
+            inH = (int) input.size(2);
+            inW = (int) input.size(3);
+            inChannels = (int) input.size(4);
+            intArgs = new int[] {size[0], size[1], size[2], 0}; // 0 is channels last
+        }
         long outD = inD * size[0];
         long outH = inH * size[1];
         long outW = inW * size[2];

-        int[] intArgs = new int[] {size[0], size[1], size[2], 1}; // 1 is channels first
-        INDArray reshapedOutput = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS,
-                input.dataType(), new long[]{miniBatch, inChannels, outD, outH, outW}, 'c');
+        INDArray output;
+        if(ncdhw){
+            output = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS,
+                input.dataType(), new long[]{miniBatch, inChannels, outD, outH, outW}, 'c');
+        } else {
+            output = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS,
+                input.dataType(), new long[]{miniBatch, outD, outH, outW, inChannels}, 'c');
+        }

         CustomOp upsampling = DynamicCustomOp.builder("upsampling3d")
                 .addIntegerArguments(intArgs)
                 .addInputs(input)
-                .addOutputs(reshapedOutput)
+                .addOutputs(output)
                 .callInplace(false)
                 .build();
         Nd4j.getExecutioner().exec(upsampling);

-        return reshapedOutput;
+        return output;
     }

     @Override
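The op's integer arguments carry the three upsampling factors plus a channels-first flag, and each output spatial dimension is simply the input dimension times its factor. A worked sketch of the arithmetic (values illustrative):

    int[] size = {2, 2, 2};     // upsampling factors: depth, height, width
    long inD = 4, inH = 5, inW = 6;

    long outD = inD * size[0];  // 8
    long outH = inH * size[1];  // 10
    long outW = inW * size[2];  // 12

    // NCDHW output shape: [miniBatch, inChannels, outD, outH, outW]
    // NDHWC output shape: [miniBatch, outD, outH, outW, inChannels]
    // intArgs: {size[0], size[1], size[2], 1} for NCDHW, {size[0], size[1], size[2], 0} for NDHWC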
@Override @Override

View File

@@ -3172,7 +3172,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {
             }
         }
         if(solver != null) {
-            return solver.getOptimizer().getUpdater();
+            return solver.getOptimizer().getUpdater(initializeIfReq);
         }
         return null;
     }

View File

@@ -42,8 +42,12 @@ public interface ConvexOptimizer extends Serializable {

     Updater getUpdater();

+    Updater getUpdater(boolean initializeIfReq);

     ComputationGraphUpdater getComputationGraphUpdater();

+    ComputationGraphUpdater getComputationGraphUpdater(boolean initializeIfReq);

     void setUpdater(Updater updater);

     void setUpdaterComputationGraph(ComputationGraphUpdater updater);

View File

@@ -115,7 +115,12 @@ public abstract class BaseOptimizer implements ConvexOptimizer {

     @Override
     public Updater getUpdater() {
-        if (updater == null) {
+        return getUpdater(true);
+    }
+
+    @Override
+    public Updater getUpdater(boolean initializeIfReq) {
+        if (updater == null && initializeIfReq) {
             updater = UpdaterCreator.getUpdater(model);
         }
         return updater;
@@ -130,7 +135,12 @@ public abstract class BaseOptimizer implements ConvexOptimizer {

     @Override
     public ComputationGraphUpdater getComputationGraphUpdater() {
-        if (computationGraphUpdater == null && model instanceof ComputationGraph) {
+        return getComputationGraphUpdater(true);
+    }
+
+    @Override
+    public ComputationGraphUpdater getComputationGraphUpdater(boolean initializeIfReq) {
+        if (computationGraphUpdater == null && model instanceof ComputationGraph && initializeIfReq) {
             computationGraphUpdater = new ComputationGraphUpdater((ComputationGraph) model);
         }
         return computationGraphUpdater;
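This overload exists so callers such as the crash-dump writer (#8033) can read the updater without creating one as a side effect. A generic, standalone sketch of the pattern, with hypothetical names:

    import java.util.function.Supplier;

    public class LazyRef<T> {
        private T value;
        private final Supplier<T> factory;

        public LazyRef(Supplier<T> factory) {
            this.factory = factory;
        }

        public T get() {                        // legacy behaviour: always initialize
            return get(true);
        }

        public T get(boolean initializeIfReq) { // read-only callers can opt out
            if (value == null && initializeIfReq) {
                value = factory.get();
            }
            return value;                       // null if absent and initialization declined
        }
    }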

View File

@@ -205,7 +205,7 @@ public class CrashReportingUtil {
         StringBuilder sb = genericMemoryStatus();

         int bytesPerElement;
-        switch (Nd4j.dataType()){
+        switch (isMLN ? mln.params().dataType() : cg.params().dataType()){
             case DOUBLE:
                 bytesPerElement = 8;
                 break;
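Sizing the dump from the network's own parameter data type, rather than the global Nd4j.dataType(), keeps the memory estimate correct when a network differs from the global default. The surrounding switch presumably completes a mapping like the following (a sketch; only the DOUBLE case appears in the excerpt above, the other widths are assumptions):

    // Assumed bytes-per-element for common ND4J data types
    int bytesPerElement;
    switch (dataType) {          // dataType: the network's parameter DataType
        case DOUBLE: bytesPerElement = 8; break;
        case FLOAT:  bytesPerElement = 4; break;
        case HALF:   bytesPerElement = 2; break;
        default:     bytesPerElement = 0; break; // unknown width
    }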

View File

@@ -200,7 +200,7 @@ public class TrainModule implements UIModule {
      * List training sessions
      * @return HTML list of training sessions
      */
-    private Result listSessions() {
+    private synchronized Result listSessions() {
         StringBuilder sb = new StringBuilder("<!DOCTYPE html>\n" +
                 "<html lang=\"en\">\n" +
                 "<head>\n" +
@@ -464,7 +464,7 @@ public class TrainModule implements UIModule {
      * @param sessionId session ID
      * @return info for session as JSON
      */
-    private Result sessionInfoForSession(String sessionId) {
+    private synchronized Result sessionInfoForSession(String sessionId) {
         Map<String, Object> dataEachSession = new HashMap<>();
         StatsStorage ss = knownSessionIDs.get(sessionId);
@@ -475,7 +475,7 @@ public class TrainModule implements UIModule {
         return Results.ok(asJson(dataEachSession)).as("application/json");
     }

-    private Result setSession(String newSessionID) {
+    private synchronized Result setSession(String newSessionID) {
         if (knownSessionIDs.containsKey(newSessionID)) {
             currentSessionID = newSessionID;
             currentWorkerIdx = 0;
@@ -567,7 +567,7 @@ public class TrainModule implements UIModule {
         return getOverviewDataForSession(currentSessionID);
     }

-    private Result getOverviewDataForSession(String sessionId) {
+    private synchronized Result getOverviewDataForSession(String sessionId) {
         Long lastUpdateTime = getLastUpdateTime(sessionId);

         I18N i18N = getI18N(sessionId);
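Marking these handlers synchronized serializes access to the module's shared session state, closing the ConcurrentModificationException window from #8017. A minimal standalone sketch of the hazard and the fix (illustrative, not the UI code itself):

    import java.util.ArrayList;
    import java.util.List;

    public class SessionListSketch {
        private final List<String> sessions = new ArrayList<>();

        // Without 'synchronized', a concurrent addSession() during this loop
        // can throw ConcurrentModificationException
        public synchronized String render() {
            StringBuilder sb = new StringBuilder();
            for (String s : sessions) {
                sb.append(s).append('\n');
            }
            return sb.toString();
        }

        public synchronized void addSession(String id) {
            sessions.add(id);
        }
    }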

View File

@@ -20,6 +20,8 @@ function selectStdevChart(fieldName) {
         $("#stdevGradients").removeAttr("class");
         $("#stdevUpdates").attr("class", "active");
     }
+
+    renderOverviewPage(false);
 }

 /* ---------- Render page ---------- */
/* ---------- Render page ---------- */ /* ---------- Render page ---------- */

View File

@@ -5207,7 +5207,7 @@ public class Nd4j {
      */
     public static void checkShapeValues(int... shape) {
         for (int e: shape) {
-            if (e < 1)
+            if (e < 0)
                 throw new ND4JIllegalStateException("Invalid shape: Requested INDArray shape " + Arrays.toString(shape)
                         + " contains dimension size values < 0 (all dimensions must be 0 or more)");
         }
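Relaxing the bound from e < 1 to e < 0 lets the int[] creator methods accept zero-size (empty) dimensions, matching the behaviour of the long[] paths tested below, while still rejecting negatives. Behaviour after the fix, as a sketch:

    // Zero-size dimensions are now valid and yield empty arrays
    INDArray empty = Nd4j.zeros(new int[]{0, 0, 0}, 'f'); // shape [0, 0, 0]

    // Negative dimensions still throw ND4JIllegalStateException
    // Nd4j.zeros(new int[]{2, -1}, 'f');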

View File

@@ -256,6 +256,7 @@ public class EmptyTests extends BaseNd4jTest {
         assertArrayEquals(new long[]{0}, Nd4j.zeros(0).shape());
         assertArrayEquals(new long[]{0,0}, Nd4j.zeros(0,0).shape());
         assertArrayEquals(new long[]{0,0,0}, Nd4j.zeros(0,0,0).shape());
+        assertArrayEquals(new long[]{0,0,0}, Nd4j.zeros(new int[]{0,0,0}, 'f').shape());

         assertArrayEquals(new long[]{0}, Nd4j.zeros(0L).shape());
         assertArrayEquals(new long[]{0}, Nd4j.zeros(dt, 0L).shape());