Various fixes (#65)

* #7977 deprecate legacy MultiLayerNetwork/ComputationGraph.params(boolean) method

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix bad test

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix Histogram mapping

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix incorrect name handling in DifferentialFunction

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Histogram fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Proper histogram fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* ToString/NDArrayStrings fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* JSON UTF8 serialization fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-07-18 18:54:07 +10:00 committed by AlexDBlack
parent c499dc962f
commit d94bc7257c
22 changed files with 54 additions and 79 deletions

View File

@ -401,6 +401,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
.dataType(DataType.DOUBLE)
.updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
.dist(new NormalDistribution(0, 1))
.seed(12345)
.list()
.layer(0, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1)
.nIn(convNIn).nOut(convNOut).hasBias(false)

View File

@ -996,10 +996,10 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
int[] mbSizes = new int[] {1, 3, 10};
for (int minibatch : mbSizes) {
INDArray in1 = Nd4j.rand(minibatch, 2);
INDArray in2 = Nd4j.rand(minibatch, 2);
INDArray in1 = Nd4j.rand(DataType.DOUBLE, minibatch, 2);
INDArray in2 = Nd4j.rand(DataType.DOUBLE, minibatch, 2);
INDArray labels = Nd4j.rand(minibatch, 1);
INDArray labels = Nd4j.rand(DataType.DOUBLE, minibatch, 1);
String testName = "testBasicL2() - minibatch = " + minibatch;

View File

@ -28,13 +28,12 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
@ -451,10 +450,10 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
//KL divergence: should be a probability distribution for labels??
ret[1] = Nd4j.rand(labelsShape);
if(labelsShape.length == 2){
Nd4j.getExecutioner().exec(new OldSoftMax(ret[1]));
Nd4j.getExecutioner().exec(new SoftMax(ret[1]));
} else if(labelsShape.length == 3) {
for (int i = 0; i < labelsShape[2]; i++) {
Nd4j.getExecutioner().exec(new OldSoftMax(ret[1].get(all(), all(), point(i))));
Nd4j.getExecutioner().exec(new SoftMax(ret[1].get(all(), all(), point(i))));
}
} else {
throw new RuntimeException();

View File

@ -310,8 +310,8 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
.build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list()
.layer(new DenseLayer.Builder().nIn(nClassesIn).nOut(embeddingDim).build())
.layer(new OutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build())
.inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
.layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build())
.setInputType(InputType.recurrent(nClassesIn))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -324,7 +324,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
int batchSize = 3;
INDArray inEmbedding = Nd4j.create(batchSize, 1);
INDArray inOneHot = Nd4j.create(batchSize, nClassesIn, 1);
INDArray outLabels = Nd4j.create(batchSize, 4);
INDArray outLabels = Nd4j.create(batchSize, 4, 1);
Random r = new Random(1337);
for (int i = 0; i < batchSize; i++) {
@ -333,7 +333,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
inOneHot.putScalar(new int[]{i, classIdx, 0}, 1.0);
int labelIdx = r.nextInt(4);
outLabels.putScalar(new int[]{i, labelIdx}, 1.0);
outLabels.putScalar(new int[]{i, labelIdx, 0}, 1.0);
}
net.setInput(inEmbedding);

View File

@ -2892,26 +2892,11 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
}
/**
* Get the parameters for the ComputationGraph
*
* @param backwardOnly If true: backprop parameters only (i.e., no visible layer biases used in layerwise pretraining layers)
* @deprecated To be removed. Use {@link #params()}
*/
@Deprecated
public INDArray params(boolean backwardOnly) {
if (backwardOnly)
return flattenedParams;
List<INDArray> list = new ArrayList<>(layers.length);
for (int i = 0; i < topologicalOrder.length; i++) {
if (!vertices[topologicalOrder[i]].hasLayer())
continue;
Layer l = vertices[topologicalOrder[i]].getLayer();
INDArray layerParams = l.params();
if (layerParams != null)
list.add(layerParams); //may be null: subsampling etc layers
}
return Nd4j.toFlattened('f', list);
return params();
}
/**
@ -3183,7 +3168,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
@Override
public INDArray params() {
return params(true);
return flattenedParams;
}
@Override

View File

@ -1433,29 +1433,11 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
/**
* Returns a 1 x m vector where the vector is composed of a flattened vector of all of the parameters (weights and
* biases etc) for all parameters in the network. Note that this method is generally reserved for developer and
* internal use - see {@link #getParam(String)} and {@link #paramTable()} for a more useful/interpretable
* representation of the parameters.<br>
* Note that with backwardsOnly = false the parameter vector is not a copy, and changes to the returned INDArray
* will impact the network parameters.
*
* @param backwardOnly Return a copy of the parameters excluding any parameters used only for unsupervised layers'
* unsupervised training (such as decoder parameters in an autoencoder layer
* @return the params for this neural net
* @deprecated To be removed. Use {@link #params()} instead
*/
@Deprecated
public INDArray params(boolean backwardOnly) {
if (backwardOnly)
return params();
List<INDArray> params = new ArrayList<>();
for (Layer layer : getLayers()) {
INDArray layerParams = layer.params();
if (layerParams != null)
params.add(layerParams); //may be null: subsampling etc layers
}
return Nd4j.toFlattened('f', params);
}

View File

@ -41,10 +41,6 @@ public class WeightInitUtil {
private WeightInitUtil() {}
public static INDArray initWeights(int[] shape, float min, float max) {
return Nd4j.rand(shape, min, max, Nd4j.getRandom());
}
/**
* Initializes a matrix with the given weight initialization scheme.

View File

@ -28,7 +28,6 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.BaseTrainingListener;
import org.deeplearning4j.ui.stats.api.*;
import org.deeplearning4j.ui.stats.impl.DefaultStatsInitializationConfiguration;
import org.deeplearning4j.ui.stats.impl.DefaultStatsUpdateConfiguration;
@ -763,11 +762,11 @@ public abstract class BaseStatsListener implements RoutingIterationListener {
for (Map.Entry<String, INDArray> entry : map.entrySet()) {
org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram hOp =
new org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram(entry.getValue(), nBins);
Nd4j.getExecutioner().exec(hOp);
org.nd4j.linalg.api.ops.impl.transforms.Histogram hOp =
new org.nd4j.linalg.api.ops.impl.transforms.Histogram(entry.getValue(), nBins);
Nd4j.exec(hOp);
INDArray bins = hOp.z();
INDArray bins = hOp.getOutputArgument(0);
int[] count = new int[nBins];
for (int i = 0; i < bins.length(); i++) {
count[i] = (int) bins.getDouble(i);

View File

@ -34,6 +34,7 @@ namespace nd4j {
REQUIRE_TRUE(numBins == output->lengthOf(), 0, "Histogram: numBins must match output length")
output->nullify();
helpers::histogramHelper(block.launchContext(), *input, *output);
return Status::OK();

View File

@ -663,9 +663,9 @@ public abstract class DifferentialFunction {
scope = "";
else
scope = scope + "/";
String varName = scope + sameDiff.generateNewVarName(opName(),argIndex).replace(":", "_");
String varName = scope + sameDiff.generateNewVarName(opName(),argIndex);
while(sameDiff.functionExists(varName)) {
varName = scope + sameDiff.generateNewVarName(opName(), argIndex).replace(":", "_");
varName = scope + sameDiff.generateNewVarName(opName(), argIndex);
argIndex++;
}

View File

@ -4589,9 +4589,10 @@ public class SameDiff extends SDBaseOps {
CustomOp op = (CustomOp)node;
extras = op.tArgs();
} else {
extras = node.getExtraArgs() != null ? new double[node.getExtraArgs().length] : new double[0];
Object[] eArgs = node.getExtraArgs();
extras = eArgs != null ? new double[eArgs.length] : new double[0];
for (int e = 0; e < extras.length; e++) {
extras[e] = ((Number) node.getExtraArgs()[e]).doubleValue();
extras[e] = ((Number) eArgs[e]).doubleValue();
}
}

View File

@ -331,6 +331,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.CheckNumerics.class,
org.nd4j.linalg.api.ops.impl.transforms.Cholesky.class,
org.nd4j.linalg.api.ops.impl.transforms.Constant.class,
org.nd4j.linalg.api.ops.impl.transforms.Histogram.class,
org.nd4j.linalg.api.ops.impl.transforms.HistogramFixedWidth.class,
org.nd4j.linalg.api.ops.impl.transforms.IdentityN.class,
org.nd4j.linalg.api.ops.impl.transforms.MaxOut.class,

View File

@ -48,7 +48,7 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum
SDVariable i_v,
boolean keepDims,
int[] dimensions) {
super(sameDiff,new Object[]{dimensions});
super(sameDiff,null);
if (i_v != null) {
this.dimensions = dimensions;
f().validateDifferentialFunctionsameDiff(i_v);
@ -70,7 +70,7 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum
SDVariable i_v2,
boolean keepDims,
int[] dimensions) {
super(sameDiff,new Object[]{dimensions});
super(sameDiff,null);
if (i_v != null) {
this.dimensions = dimensions;
f().validateDifferentialFunctionsameDiff(i_v);

View File

@ -61,7 +61,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
public BaseReduceOp(SameDiff sameDiff,
SDVariable i_v,
int[] dimensions, boolean keepDims) {
super(sameDiff,new Object[]{dimensions});
super(sameDiff, null);
if (i_v != null) {
if(dimensions == null || dimensions.length < 1)
dimensions = new int[] {Integer.MAX_VALUE};
@ -86,7 +86,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
SDVariable i_v,
SDVariable i_v2,
int[] dimensions, boolean keepDims) {
super(sameDiff,new Object[]{dimensions});
super(sameDiff,null);
if (i_v != null) {
if(dimensions == null || dimensions.length < 1)
dimensions = new int[] {Integer.MAX_VALUE};

View File

@ -2089,7 +2089,9 @@ public class Shape {
}
// we need to wrap buffer of a current array, to make sure it's properly marked as a View
INDArray ret = Nd4j.create(Nd4j.createBuffer(arr.data(), arr.offset(), arr.length()), newShape, newStrides, arr.offset(), isFOrder ? 'f' : 'c');
DataBuffer db = arr.data();
DataBuffer buffer = Nd4j.createBuffer(db, arr.offset(), arr.length());
INDArray ret = Nd4j.create(buffer, newShape, newStrides, arr.offset(), isFOrder ? 'f' : 'c');
return ret;
}

View File

@ -305,7 +305,7 @@ public class NDArrayStrings {
}
}
if (i < l - 1) {
if (!summarize || i < 2 || i > l - 3 || (summarize && l == 6)) {
if (!summarize || i <= 2 || i >= l - 3 || (summarize && l == 6)) {
sb.append(colSep);
}
}

View File

@ -17,6 +17,7 @@
package org.nd4j.serde.jackson.shaded;
import org.nd4j.linalg.api.buffer.Utf8Buffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.serde.base64.Nd4jBase64;
@ -76,9 +77,12 @@ public class NDArrayTextSerializer extends JsonSerializer<INDArray> {
jg.writeNumber(v);
break;
case UTF8:
String[] str = new String[(int)arr.length()];
for( int j=0; j<str.length; j++ )
jg.writeString(arr.getString(j));
Utf8Buffer utf8B = ((Utf8Buffer)arr.data());
long n = utf8B.getNumWords();
for( int j=0; j<n; j++ ) {
String s = utf8B.getString(j);
jg.writeString(s);
}
break;
case COMPRESSED:
case UNKNOWN:

View File

@ -68,7 +68,7 @@ public class TestOpMapping extends BaseNd4jTest {
}
String opName = df.opName();
assertTrue(opName, opNameMapping.containsKey(opName));
assertTrue("Op is missing - not defined in ImportClassMapping: " + opName, opNameMapping.containsKey(opName));
try{
String[] tfNames = df.tensorflowNames();

View File

@ -44,7 +44,7 @@ public class ToStringTest extends BaseNd4jTest {
assertEquals("[ 1, 2, 3]",
Nd4j.createFromArray(1, 2, 3).toString());
assertEquals("[ 1, 2, 3 ... 6 7, 8]",
assertEquals("[ 1, 2, 3, 4, 5, 6, 7, 8]",
Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8).toString(1000, false, 2));
assertEquals("[ 1.132, 2.644, 3.234]",
@ -53,9 +53,8 @@ public class ToStringTest extends BaseNd4jTest {
assertEquals("[ 1.132414, 2.64356456, 3.25345234]",
Nd4j.createFromArray(1.132414, 2.64356456, 3.25345234).toStringFull());
assertEquals("[ 1, 2, 3 ... 6 7, 8]",
Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8).toString(100, true, 1));
assertEquals("[ 1, 2, 3, ... 6, 7, 8]",
Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8).toString(6, true, 1));
}
@Override

View File

@ -66,7 +66,7 @@ public class JsonSerdeTests extends BaseNd4jTest {
INDArray arr;
if(dt == DataType.UTF8){
arr = Nd4j.create("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l").reshape('c', 3, 4);
arr = Nd4j.create("aaaaa", "bbbb", "ccc", "dd", "e", "f", "g", "h", "i", "j", "k", "l").reshape('c', 3, 4);
} else {
arr = in.castTo(dt);
}

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.api.buffer;
import lombok.Getter;
import lombok.NonNull;
import lombok.val;
import org.bytedeco.javacpp.BytePointer;
@ -42,6 +43,7 @@ public class Utf8Buffer extends BaseDataBuffer {
protected Collection<Pointer> references = new ArrayList<>();
@Getter
protected long numWords = 0;
/**
@ -121,6 +123,7 @@ public class Utf8Buffer extends BaseDataBuffer {
public Utf8Buffer(DataBuffer underlyingBuffer, long length, long offset) {
super(underlyingBuffer, length, offset);
this.numWords = length;
}
public Utf8Buffer(@NonNull Collection<String> strings) {

View File

@ -87,6 +87,8 @@ public class DefaultDataBufferFactory implements DataBufferFactory {
return new BFloat16Buffer(underlyingBuffer, length, offset);
} else if (underlyingBuffer.dataType() == DataType.HALF) {
return new HalfBuffer(underlyingBuffer, length, offset);
} else if (underlyingBuffer.dataType() == DataType.UTF8) {
return new Utf8Buffer(underlyingBuffer, length, offset);
}
return null;
}