From d94bc7257c5ebe7751f01686706432b223fd9cc2 Mon Sep 17 00:00:00 2001
From: Alex Black
Date: Thu, 18 Jul 2019 18:54:07 +1000
Subject: [PATCH] Various fixes (#65)

* #7977 deprecate legacy MultiLayerNetwork/ComputationGraph.params(boolean) method

Signed-off-by: AlexDBlack

* Fix bad test

Signed-off-by: AlexDBlack

* Small fixes

Signed-off-by: AlexDBlack

* Fix Histogram mapping

Signed-off-by: AlexDBlack

* Fix incorrect name handling in DifferentialFunction

Signed-off-by: AlexDBlack

* More fixes

Signed-off-by: AlexDBlack

* Histogram fixes

Signed-off-by: AlexDBlack

* Proper histogram fix

Signed-off-by: AlexDBlack

* ToString/NDArrayStrings fix

Signed-off-by: AlexDBlack

* JSON UTF8 serialization fix

Signed-off-by: AlexDBlack
---
 .../gradientcheck/CNN3DGradientCheckTest.java |  1 +
 .../GradientCheckTestsComputationGraph.java   |  6 +++---
 .../LossFunctionGradientCheck.java            |  7 +++----
 .../embedding/EmbeddingLayerTest.java         |  8 ++++----
 .../nn/graph/ComputationGraph.java            | 23 ++++-------------------
 .../nn/multilayer/MultiLayerNetwork.java      | 24 +++---------------------
 .../nn/weights/WeightInitUtil.java            |  4 ----
 .../ui/stats/BaseStatsListener.java           |  9 ++++-----
 .../generic/transforms/histogram.cpp          |  1 +
 .../functions/DifferentialFunction.java       |  4 ++--
 .../org/nd4j/autodiff/samediff/SameDiff.java  |  5 +++--
 .../converters/ImportClassMapping.java        |  1 +
 .../linalg/api/ops/BaseIndexAccumulation.java |  4 ++--
 .../org/nd4j/linalg/api/ops/BaseReduceOp.java |  4 ++--
 .../java/org/nd4j/linalg/api/shape/Shape.java |  4 +++-
 .../nd4j/linalg/string/NDArrayStrings.java    |  2 +-
 .../jackson/shaded/NDArrayTextSerializer.java | 10 +++++++---
 .../java/org/nd4j/autodiff/TestOpMapping.java |  2 +-
 .../java/org/nd4j/linalg/ToStringTest.java    |  7 ++++---
 .../org/nd4j/linalg/serde/JsonSerdeTests.java |  2 +-
 .../nd4j/linalg/api/buffer/Utf8Buffer.java    |  3 +++
 .../factory/DefaultDataBufferFactory.java     |  2 ++
 22 files changed, 54 insertions(+), 79 deletions(-)

diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java
index 01fe1fe9b..227b72527 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java
@@ -401,6 +401,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
                             .dataType(DataType.DOUBLE)
                             .updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
                             .dist(new NormalDistribution(0, 1))
+                            .seed(12345)
                             .list()
                             .layer(0, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1)
                                     .nIn(convNIn).nOut(convNOut).hasBias(false)
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java
index e0e7130ed..1eb893e3f 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java
@@ -996,10 +996,10 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
         int[] mbSizes = new int[] {1, 3, 10};
         for (int minibatch : mbSizes) {

-            INDArray in1 = Nd4j.rand(minibatch, 2);
-            INDArray in2 = Nd4j.rand(minibatch, 2);
+            INDArray in1 = Nd4j.rand(DataType.DOUBLE, minibatch, 2);
+            INDArray in2 = Nd4j.rand(DataType.DOUBLE, minibatch, 2);

-            INDArray labels = Nd4j.rand(minibatch, 1);
+            INDArray labels = Nd4j.rand(DataType.DOUBLE, minibatch, 1);

             String testName = "testBasicL2() - minibatch = " + minibatch;
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java
index bf06551a1..a68aa9a57 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java
@@ -28,13 +28,12 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer;
 import org.deeplearning4j.nn.conf.layers.LossLayer;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
-import org.junit.Ignore;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.activations.impl.ActivationIdentity;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
 import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.BooleanIndexing;
@@ -451,10 +450,10 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                 //KL divergence: should be a probability distribution for labels??
                 ret[1] = Nd4j.rand(labelsShape);
                 if(labelsShape.length == 2){
-                    Nd4j.getExecutioner().exec(new OldSoftMax(ret[1]));
+                    Nd4j.getExecutioner().exec(new SoftMax(ret[1]));
                 } else if(labelsShape.length == 3) {
                     for (int i = 0; i < labelsShape[2]; i++) {
-                        Nd4j.getExecutioner().exec(new OldSoftMax(ret[1].get(all(), all(), point(i))));
+                        Nd4j.getExecutioner().exec(new SoftMax(ret[1].get(all(), all(), point(i))));
                     }
                 } else {
                     throw new RuntimeException();
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java
index d8921346d..20a6b34cf 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java
@@ -310,8 +310,8 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
                 .build();
         MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list()
                 .layer(new DenseLayer.Builder().nIn(nClassesIn).nOut(embeddingDim).build())
-                .layer(new OutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build())
-                .inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
+                .layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build())
+                .setInputType(InputType.recurrent(nClassesIn))
                 .build();

         MultiLayerNetwork net = new MultiLayerNetwork(conf);
@@ -324,7 +324,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
         int batchSize = 3;
         INDArray inEmbedding = Nd4j.create(batchSize, 1);
         INDArray inOneHot = Nd4j.create(batchSize, nClassesIn, 1);
-        INDArray outLabels = Nd4j.create(batchSize, 4);
+        INDArray outLabels = Nd4j.create(batchSize, 4, 1);

         Random r = new Random(1337);
         for (int i = 0; i < batchSize; i++) {
@@ -333,7 +333,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
             inOneHot.putScalar(new int[]{i, classIdx, 0}, 1.0);

             int labelIdx = r.nextInt(4);
-            outLabels.putScalar(new int[]{i, labelIdx}, 1.0);
+            outLabels.putScalar(new int[]{i, labelIdx, 0}, 1.0);
         }

         net.setInput(inEmbedding);
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
index e3e080114..daa91be7f 100755
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
@@ -2892,26 +2892,11 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
     }

     /**
-     * Get the parameters for the ComputationGraph
-     *
-     * @param backwardOnly If true: backprop parameters only (i.e., no visible layer biases used in layerwise pretraining layers)
+     * @deprecated To be removed. Use {@link #params()}
      */
+    @Deprecated
     public INDArray params(boolean backwardOnly) {
-        if (backwardOnly)
-            return flattenedParams;
-
-        List<INDArray> list = new ArrayList<>(layers.length);
-        for (int i = 0; i < topologicalOrder.length; i++) {
-            if (!vertices[topologicalOrder[i]].hasLayer())
-                continue;
-
-            Layer l = vertices[topologicalOrder[i]].getLayer();
-            INDArray layerParams = l.params();
-            if (layerParams != null)
-                list.add(layerParams); //may be null: subsampling etc layers
-        }
-
-        return Nd4j.toFlattened('f', list);
+        return params();
     }

     /**
@@ -3183,7 +3168,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {

     @Override
     public INDArray params() {
-        return params(true);
+        return flattenedParams;
     }

     @Override
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java
index ef09f7780..c60e853b2 100755
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java
@@ -1433,29 +1433,11 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {

     /**
-     * Returns a 1 x m vector where the vector is composed of a flattened vector of all of the parameters (weights and
-     * biases etc) for all parameters in the network. Note that this method is generally reserved for developer and
-     * internal use - see {@link #getParam(String)} and {@link #paramTable()} for a more useful/interpretable
-     * representation of the parameters.<br>
-     * Note that with backwardsOnly = false the parameter vector is not a copy, and changes to the returned INDArray
-     * will impact the network parameters.
-     *
-     * @param backwardOnly Return a copy of the parameters excluding any parameters used only for unsupervised layers'
-     *                     unsupervised training (such as decoder parameters in an autoencoder layer)
-     * @return the params for this neural net
+     * @deprecated To be removed. Use {@link #params()} instead
      */
+    @Deprecated
     public INDArray params(boolean backwardOnly) {
-        if (backwardOnly)
-            return params();
-
-        List<INDArray> params = new ArrayList<>();
-        for (Layer layer : getLayers()) {
-            INDArray layerParams = layer.params();
-            if (layerParams != null)
-                params.add(layerParams); //may be null: subsampling etc layers
-        }
-
-        return Nd4j.toFlattened('f', params);
+        return params();
     }
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitUtil.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitUtil.java
index 6bf62de1d..b110bc5a0 100755
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitUtil.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/weights/WeightInitUtil.java
@@ -41,10 +41,6 @@ public class WeightInitUtil {

     private WeightInitUtil() {}

-    public static INDArray initWeights(int[] shape, float min, float max) {
-        return Nd4j.rand(shape, min, max, Nd4j.getRandom());
-    }
-
     /**
      * Initializes a matrix with the given weight initialization scheme.
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/stats/BaseStatsListener.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/stats/BaseStatsListener.java
index 1cb9014b6..233633ab9 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/stats/BaseStatsListener.java
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/stats/BaseStatsListener.java
@@ -28,7 +28,6 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
-import org.deeplearning4j.optimize.api.BaseTrainingListener;
 import org.deeplearning4j.ui.stats.api.*;
 import org.deeplearning4j.ui.stats.impl.DefaultStatsInitializationConfiguration;
 import org.deeplearning4j.ui.stats.impl.DefaultStatsUpdateConfiguration;
@@ -763,11 +762,11 @@ public abstract class BaseStatsListener implements RoutingIterationListener {

         for (Map.Entry<String, INDArray> entry : map.entrySet()) {
-            org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram hOp =
-                    new org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram(entry.getValue(), nBins);
-            Nd4j.getExecutioner().exec(hOp);
+            org.nd4j.linalg.api.ops.impl.transforms.Histogram hOp =
+                    new org.nd4j.linalg.api.ops.impl.transforms.Histogram(entry.getValue(), nBins);
+            Nd4j.exec(hOp);

-            INDArray bins = hOp.z();
+            INDArray bins = hOp.getOutputArgument(0);
             int[] count = new int[nBins];
             for (int i = 0; i < bins.length(); i++) {
                 count[i] = (int) bins.getDouble(i);
diff --git a/libnd4j/include/ops/declarable/generic/transforms/histogram.cpp b/libnd4j/include/ops/declarable/generic/transforms/histogram.cpp
index 3581dcd9a..ab5a70c4b 100644
--- a/libnd4j/include/ops/declarable/generic/transforms/histogram.cpp
+++ b/libnd4j/include/ops/declarable/generic/transforms/histogram.cpp
@@ -34,6 +34,7 @@ namespace nd4j {

             REQUIRE_TRUE(numBins == output->lengthOf(), 0, "Histogram: numBins must match output length")

+            output->nullify();
             helpers::histogramHelper(block.launchContext(), *input, *output);

             return Status::OK();
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java
index 133686b57..5f5ab2a7c 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java
@@ -663,9 +663,9 @@ public abstract class DifferentialFunction {
             scope = "";
         else
             scope = scope + "/";
-        String varName = scope + sameDiff.generateNewVarName(opName(),argIndex).replace(":", "_");
+        String varName = scope + sameDiff.generateNewVarName(opName(),argIndex);
         while(sameDiff.functionExists(varName)) {
-            varName = scope + sameDiff.generateNewVarName(opName(), argIndex).replace(":", "_");
+            varName = scope + sameDiff.generateNewVarName(opName(), argIndex);
             argIndex++;
         }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
index 2e7829913..2b2ba63dc 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
@@ -4589,9 +4589,10 @@ public class SameDiff extends SDBaseOps {
                 CustomOp op = (CustomOp)node;
                 extras = op.tArgs();
             } else {
-                extras = node.getExtraArgs() != null ? new double[node.getExtraArgs().length] : new double[0];
+                Object[] eArgs = node.getExtraArgs();
+                extras = eArgs != null ? new double[eArgs.length] : new double[0];
                 for (int e = 0; e < extras.length; e++) {
-                    extras[e] = ((Number) node.getExtraArgs()[e]).doubleValue();
+                    extras[e] = ((Number) eArgs[e]).doubleValue();
                 }
             }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
index 702660d74..19bdb6d55 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
@@ -331,6 +331,7 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.impl.transforms.CheckNumerics.class,
             org.nd4j.linalg.api.ops.impl.transforms.Cholesky.class,
             org.nd4j.linalg.api.ops.impl.transforms.Constant.class,
+            org.nd4j.linalg.api.ops.impl.transforms.Histogram.class,
             org.nd4j.linalg.api.ops.impl.transforms.HistogramFixedWidth.class,
             org.nd4j.linalg.api.ops.impl.transforms.IdentityN.class,
             org.nd4j.linalg.api.ops.impl.transforms.MaxOut.class,
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java
index 75cb5dded..6e5682962 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java
@@ -48,7 +48,7 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccumulation {
                                 SDVariable i_v,
                                 boolean keepDims,
                                 int[] dimensions) {
-        super(sameDiff,new Object[]{dimensions});
+        super(sameDiff,null);
         if (i_v != null) {
             this.dimensions = dimensions;
             f().validateDifferentialFunctionsameDiff(i_v);
@@ -70,7 +70,7 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccumulation {
                                 SDVariable i_v2,
                                 boolean keepDims,
                                 int[] dimensions) {
-        super(sameDiff,new Object[]{dimensions});
+        super(sameDiff,null);
         if (i_v != null) {
             this.dimensions = dimensions;
             f().validateDifferentialFunctionsameDiff(i_v);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java
index 25a04125d..ebf9b9c18 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java
@@ -61,7 +61,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
     public BaseReduceOp(SameDiff sameDiff,
                         SDVariable i_v,
                         int[] dimensions, boolean keepDims) {
-        super(sameDiff,new Object[]{dimensions});
+        super(sameDiff, null);
         if (i_v != null) {
             if(dimensions == null || dimensions.length < 1)
                 dimensions = new int[] {Integer.MAX_VALUE};
@@ -86,7 +86,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
                         SDVariable i_v,
                         SDVariable i_v2,
                         int[] dimensions, boolean keepDims) {
-        super(sameDiff,new Object[]{dimensions});
+        super(sameDiff,null);
         if (i_v != null) {
             if(dimensions == null || dimensions.length < 1)
                 dimensions = new int[] {Integer.MAX_VALUE};
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java
index a35fe8f43..76f50d733 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java
@@ -2089,7 +2089,9 @@ public class Shape {
         }

         // we need to wrap buffer of a current array, to make sure it's properly marked as a View
-        INDArray ret = Nd4j.create(Nd4j.createBuffer(arr.data(), arr.offset(), arr.length()), newShape, newStrides, arr.offset(), isFOrder ? 'f' : 'c');
+        DataBuffer db = arr.data();
+        DataBuffer buffer = Nd4j.createBuffer(db, arr.offset(), arr.length());
+        INDArray ret = Nd4j.create(buffer, newShape, newStrides, arr.offset(), isFOrder ? 'f' : 'c');

         return ret;
     }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/string/NDArrayStrings.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/string/NDArrayStrings.java
index d6de21130..c28f35151 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/string/NDArrayStrings.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/string/NDArrayStrings.java
@@ -305,7 +305,7 @@ public class NDArrayStrings {
                 }
             }
             if (i < l - 1) {
-                if (!summarize || i < 2 || i > l - 3 || (summarize && l == 6)) {
+                if (!summarize || i <= 2 || i >= l - 3 || (summarize && l == 6)) {
                     sb.append(colSep);
                 }
             }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/jackson/shaded/NDArrayTextSerializer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/jackson/shaded/NDArrayTextSerializer.java
index 8007c2d54..5e966f850 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/jackson/shaded/NDArrayTextSerializer.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/jackson/shaded/NDArrayTextSerializer.java
@@ -17,6 +17,7 @@
 package org.nd4j.serde.jackson.shaded;


+import org.nd4j.linalg.api.buffer.Utf8Buffer;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.shape.Shape;
 import org.nd4j.serde.base64.Nd4jBase64;
@@ -76,9 +77,12 @@ public class NDArrayTextSerializer extends JsonSerializer<INDArray> {
                 jg.writeNumber(v);
                 break;
             case UTF8:
-                String[] str = new String[(int)arr.length()];
-                for( int j=0; j
diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Utf8Buffer.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Utf8Buffer.java
--- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Utf8Buffer.java
+++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Utf8Buffer.java
 references = new ArrayList<>();
+    @Getter
+    protected long numWords = 0;

     /**
@@ -121,6 +123,7 @@ public class Utf8Buffer extends BaseDataBuffer {

     public Utf8Buffer(DataBuffer underlyingBuffer, long length, long offset) {
         super(underlyingBuffer, length, offset);
+        this.numWords = length;
     }

     public Utf8Buffer(@NonNull Collection<String> strings) {
diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/factory/DefaultDataBufferFactory.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/factory/DefaultDataBufferFactory.java
index 96b154338..65d605e00 100644
--- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/factory/DefaultDataBufferFactory.java
+++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/factory/DefaultDataBufferFactory.java
@@ -87,6 +87,8 @@ public class DefaultDataBufferFactory implements DataBufferFactory {
             return new BFloat16Buffer(underlyingBuffer, length, offset);
         } else if (underlyingBuffer.dataType() == DataType.HALF) {
             return new HalfBuffer(underlyingBuffer, length, offset);
+        } else if (underlyingBuffer.dataType() == DataType.UTF8) {
+            return new Utf8Buffer(underlyingBuffer, length, offset);
         }
         return null;
     }