Various fixes (#65)
* #7977 deprecate legacy MultiLayerNetwork/ComputationGraph.params(boolean) method Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix bad test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix Histogram mapping Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix incorrect name handling in DifferentialFunction Signed-off-by: AlexDBlack <blacka101@gmail.com> * More fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Histogram fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Proper histogram fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * ToString/NDArrayStrings fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * JSON UTF8 serialization fix Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
c499dc962f
commit
d94bc7257c
|
@ -401,6 +401,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
|
||||||
.dataType(DataType.DOUBLE)
|
.dataType(DataType.DOUBLE)
|
||||||
.updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
|
.updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
|
||||||
.dist(new NormalDistribution(0, 1))
|
.dist(new NormalDistribution(0, 1))
|
||||||
|
.seed(12345)
|
||||||
.list()
|
.list()
|
||||||
.layer(0, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1)
|
.layer(0, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1)
|
||||||
.nIn(convNIn).nOut(convNOut).hasBias(false)
|
.nIn(convNIn).nOut(convNOut).hasBias(false)
|
||||||
|
|
|
@ -996,10 +996,10 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
|
||||||
int[] mbSizes = new int[] {1, 3, 10};
|
int[] mbSizes = new int[] {1, 3, 10};
|
||||||
for (int minibatch : mbSizes) {
|
for (int minibatch : mbSizes) {
|
||||||
|
|
||||||
INDArray in1 = Nd4j.rand(minibatch, 2);
|
INDArray in1 = Nd4j.rand(DataType.DOUBLE, minibatch, 2);
|
||||||
INDArray in2 = Nd4j.rand(minibatch, 2);
|
INDArray in2 = Nd4j.rand(DataType.DOUBLE, minibatch, 2);
|
||||||
|
|
||||||
INDArray labels = Nd4j.rand(minibatch, 1);
|
INDArray labels = Nd4j.rand(DataType.DOUBLE, minibatch, 1);
|
||||||
|
|
||||||
String testName = "testBasicL2() - minibatch = " + minibatch;
|
String testName = "testBasicL2() - minibatch = " + minibatch;
|
||||||
|
|
||||||
|
|
|
@ -28,13 +28,12 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.LossLayer;
|
import org.deeplearning4j.nn.conf.layers.LossLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.junit.Ignore;
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.activations.impl.ActivationIdentity;
|
import org.nd4j.linalg.activations.impl.ActivationIdentity;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
|
||||||
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.BooleanIndexing;
|
import org.nd4j.linalg.indexing.BooleanIndexing;
|
||||||
|
@ -451,10 +450,10 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
|
||||||
//KL divergence: should be a probability distribution for labels??
|
//KL divergence: should be a probability distribution for labels??
|
||||||
ret[1] = Nd4j.rand(labelsShape);
|
ret[1] = Nd4j.rand(labelsShape);
|
||||||
if(labelsShape.length == 2){
|
if(labelsShape.length == 2){
|
||||||
Nd4j.getExecutioner().exec(new OldSoftMax(ret[1]));
|
Nd4j.getExecutioner().exec(new SoftMax(ret[1]));
|
||||||
} else if(labelsShape.length == 3) {
|
} else if(labelsShape.length == 3) {
|
||||||
for (int i = 0; i < labelsShape[2]; i++) {
|
for (int i = 0; i < labelsShape[2]; i++) {
|
||||||
Nd4j.getExecutioner().exec(new OldSoftMax(ret[1].get(all(), all(), point(i))));
|
Nd4j.getExecutioner().exec(new SoftMax(ret[1].get(all(), all(), point(i))));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
throw new RuntimeException();
|
throw new RuntimeException();
|
||||||
|
|
|
@ -310,8 +310,8 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
|
||||||
.build();
|
.build();
|
||||||
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list()
|
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list()
|
||||||
.layer(new DenseLayer.Builder().nIn(nClassesIn).nOut(embeddingDim).build())
|
.layer(new DenseLayer.Builder().nIn(nClassesIn).nOut(embeddingDim).build())
|
||||||
.layer(new OutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build())
|
.layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build())
|
||||||
.inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
|
.setInputType(InputType.recurrent(nClassesIn))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
@ -324,7 +324,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
|
||||||
int batchSize = 3;
|
int batchSize = 3;
|
||||||
INDArray inEmbedding = Nd4j.create(batchSize, 1);
|
INDArray inEmbedding = Nd4j.create(batchSize, 1);
|
||||||
INDArray inOneHot = Nd4j.create(batchSize, nClassesIn, 1);
|
INDArray inOneHot = Nd4j.create(batchSize, nClassesIn, 1);
|
||||||
INDArray outLabels = Nd4j.create(batchSize, 4);
|
INDArray outLabels = Nd4j.create(batchSize, 4, 1);
|
||||||
|
|
||||||
Random r = new Random(1337);
|
Random r = new Random(1337);
|
||||||
for (int i = 0; i < batchSize; i++) {
|
for (int i = 0; i < batchSize; i++) {
|
||||||
|
@ -333,7 +333,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
|
||||||
inOneHot.putScalar(new int[]{i, classIdx, 0}, 1.0);
|
inOneHot.putScalar(new int[]{i, classIdx, 0}, 1.0);
|
||||||
|
|
||||||
int labelIdx = r.nextInt(4);
|
int labelIdx = r.nextInt(4);
|
||||||
outLabels.putScalar(new int[]{i, labelIdx}, 1.0);
|
outLabels.putScalar(new int[]{i, labelIdx, 0}, 1.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
net.setInput(inEmbedding);
|
net.setInput(inEmbedding);
|
||||||
|
|
|
@ -2892,26 +2892,11 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the parameters for the ComputationGraph
|
* @deprecated To be removed. Use {@link #params()}
|
||||||
*
|
|
||||||
* @param backwardOnly If true: backprop parameters only (i.e., no visible layer biases used in layerwise pretraining layers)
|
|
||||||
*/
|
*/
|
||||||
|
@Deprecated
|
||||||
public INDArray params(boolean backwardOnly) {
|
public INDArray params(boolean backwardOnly) {
|
||||||
if (backwardOnly)
|
return params();
|
||||||
return flattenedParams;
|
|
||||||
|
|
||||||
List<INDArray> list = new ArrayList<>(layers.length);
|
|
||||||
for (int i = 0; i < topologicalOrder.length; i++) {
|
|
||||||
if (!vertices[topologicalOrder[i]].hasLayer())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
Layer l = vertices[topologicalOrder[i]].getLayer();
|
|
||||||
INDArray layerParams = l.params();
|
|
||||||
if (layerParams != null)
|
|
||||||
list.add(layerParams); //may be null: subsampling etc layers
|
|
||||||
}
|
|
||||||
|
|
||||||
return Nd4j.toFlattened('f', list);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -3183,7 +3168,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray params() {
|
public INDArray params() {
|
||||||
return params(true);
|
return flattenedParams;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -1433,29 +1433,11 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a 1 x m vector where the vector is composed of a flattened vector of all of the parameters (weights and
|
* @deprecated To be removed. Use {@link #params()} instead
|
||||||
* biases etc) for all parameters in the network. Note that this method is generally reserved for developer and
|
|
||||||
* internal use - see {@link #getParam(String)} and {@link #paramTable()} for a more useful/interpretable
|
|
||||||
* representation of the parameters.<br>
|
|
||||||
* Note that with backwardsOnly = false the parameter vector is not a copy, and changes to the returned INDArray
|
|
||||||
* will impact the network parameters.
|
|
||||||
*
|
|
||||||
* @param backwardOnly Return a copy of the parameters excluding any parameters used only for unsupervised layers'
|
|
||||||
* unsupervised training (such as decoder parameters in an autoencoder layer
|
|
||||||
* @return the params for this neural net
|
|
||||||
*/
|
*/
|
||||||
|
@Deprecated
|
||||||
public INDArray params(boolean backwardOnly) {
|
public INDArray params(boolean backwardOnly) {
|
||||||
if (backwardOnly)
|
return params();
|
||||||
return params();
|
|
||||||
|
|
||||||
List<INDArray> params = new ArrayList<>();
|
|
||||||
for (Layer layer : getLayers()) {
|
|
||||||
INDArray layerParams = layer.params();
|
|
||||||
if (layerParams != null)
|
|
||||||
params.add(layerParams); //may be null: subsampling etc layers
|
|
||||||
}
|
|
||||||
|
|
||||||
return Nd4j.toFlattened('f', params);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -41,10 +41,6 @@ public class WeightInitUtil {
|
||||||
|
|
||||||
private WeightInitUtil() {}
|
private WeightInitUtil() {}
|
||||||
|
|
||||||
public static INDArray initWeights(int[] shape, float min, float max) {
|
|
||||||
return Nd4j.rand(shape, min, max, Nd4j.getRandom());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initializes a matrix with the given weight initialization scheme.
|
* Initializes a matrix with the given weight initialization scheme.
|
||||||
|
|
|
@ -28,7 +28,6 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.optimize.api.BaseTrainingListener;
|
|
||||||
import org.deeplearning4j.ui.stats.api.*;
|
import org.deeplearning4j.ui.stats.api.*;
|
||||||
import org.deeplearning4j.ui.stats.impl.DefaultStatsInitializationConfiguration;
|
import org.deeplearning4j.ui.stats.impl.DefaultStatsInitializationConfiguration;
|
||||||
import org.deeplearning4j.ui.stats.impl.DefaultStatsUpdateConfiguration;
|
import org.deeplearning4j.ui.stats.impl.DefaultStatsUpdateConfiguration;
|
||||||
|
@ -763,11 +762,11 @@ public abstract class BaseStatsListener implements RoutingIterationListener {
|
||||||
|
|
||||||
for (Map.Entry<String, INDArray> entry : map.entrySet()) {
|
for (Map.Entry<String, INDArray> entry : map.entrySet()) {
|
||||||
|
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram hOp =
|
org.nd4j.linalg.api.ops.impl.transforms.Histogram hOp =
|
||||||
new org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram(entry.getValue(), nBins);
|
new org.nd4j.linalg.api.ops.impl.transforms.Histogram(entry.getValue(), nBins);
|
||||||
Nd4j.getExecutioner().exec(hOp);
|
Nd4j.exec(hOp);
|
||||||
|
|
||||||
INDArray bins = hOp.z();
|
INDArray bins = hOp.getOutputArgument(0);
|
||||||
int[] count = new int[nBins];
|
int[] count = new int[nBins];
|
||||||
for (int i = 0; i < bins.length(); i++) {
|
for (int i = 0; i < bins.length(); i++) {
|
||||||
count[i] = (int) bins.getDouble(i);
|
count[i] = (int) bins.getDouble(i);
|
||||||
|
|
|
@ -34,6 +34,7 @@ namespace nd4j {
|
||||||
|
|
||||||
REQUIRE_TRUE(numBins == output->lengthOf(), 0, "Histogram: numBins must match output length")
|
REQUIRE_TRUE(numBins == output->lengthOf(), 0, "Histogram: numBins must match output length")
|
||||||
|
|
||||||
|
output->nullify();
|
||||||
helpers::histogramHelper(block.launchContext(), *input, *output);
|
helpers::histogramHelper(block.launchContext(), *input, *output);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
|
@ -663,9 +663,9 @@ public abstract class DifferentialFunction {
|
||||||
scope = "";
|
scope = "";
|
||||||
else
|
else
|
||||||
scope = scope + "/";
|
scope = scope + "/";
|
||||||
String varName = scope + sameDiff.generateNewVarName(opName(),argIndex).replace(":", "_");
|
String varName = scope + sameDiff.generateNewVarName(opName(),argIndex);
|
||||||
while(sameDiff.functionExists(varName)) {
|
while(sameDiff.functionExists(varName)) {
|
||||||
varName = scope + sameDiff.generateNewVarName(opName(), argIndex).replace(":", "_");
|
varName = scope + sameDiff.generateNewVarName(opName(), argIndex);
|
||||||
argIndex++;
|
argIndex++;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4589,9 +4589,10 @@ public class SameDiff extends SDBaseOps {
|
||||||
CustomOp op = (CustomOp)node;
|
CustomOp op = (CustomOp)node;
|
||||||
extras = op.tArgs();
|
extras = op.tArgs();
|
||||||
} else {
|
} else {
|
||||||
extras = node.getExtraArgs() != null ? new double[node.getExtraArgs().length] : new double[0];
|
Object[] eArgs = node.getExtraArgs();
|
||||||
|
extras = eArgs != null ? new double[eArgs.length] : new double[0];
|
||||||
for (int e = 0; e < extras.length; e++) {
|
for (int e = 0; e < extras.length; e++) {
|
||||||
extras[e] = ((Number) node.getExtraArgs()[e]).doubleValue();
|
extras[e] = ((Number) eArgs[e]).doubleValue();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -331,6 +331,7 @@ public class ImportClassMapping {
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.CheckNumerics.class,
|
org.nd4j.linalg.api.ops.impl.transforms.CheckNumerics.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.Cholesky.class,
|
org.nd4j.linalg.api.ops.impl.transforms.Cholesky.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.Constant.class,
|
org.nd4j.linalg.api.ops.impl.transforms.Constant.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.Histogram.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.HistogramFixedWidth.class,
|
org.nd4j.linalg.api.ops.impl.transforms.HistogramFixedWidth.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.IdentityN.class,
|
org.nd4j.linalg.api.ops.impl.transforms.IdentityN.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.MaxOut.class,
|
org.nd4j.linalg.api.ops.impl.transforms.MaxOut.class,
|
||||||
|
|
|
@ -48,7 +48,7 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum
|
||||||
SDVariable i_v,
|
SDVariable i_v,
|
||||||
boolean keepDims,
|
boolean keepDims,
|
||||||
int[] dimensions) {
|
int[] dimensions) {
|
||||||
super(sameDiff,new Object[]{dimensions});
|
super(sameDiff,null);
|
||||||
if (i_v != null) {
|
if (i_v != null) {
|
||||||
this.dimensions = dimensions;
|
this.dimensions = dimensions;
|
||||||
f().validateDifferentialFunctionsameDiff(i_v);
|
f().validateDifferentialFunctionsameDiff(i_v);
|
||||||
|
@ -70,7 +70,7 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum
|
||||||
SDVariable i_v2,
|
SDVariable i_v2,
|
||||||
boolean keepDims,
|
boolean keepDims,
|
||||||
int[] dimensions) {
|
int[] dimensions) {
|
||||||
super(sameDiff,new Object[]{dimensions});
|
super(sameDiff,null);
|
||||||
if (i_v != null) {
|
if (i_v != null) {
|
||||||
this.dimensions = dimensions;
|
this.dimensions = dimensions;
|
||||||
f().validateDifferentialFunctionsameDiff(i_v);
|
f().validateDifferentialFunctionsameDiff(i_v);
|
||||||
|
|
|
@ -61,7 +61,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
|
||||||
public BaseReduceOp(SameDiff sameDiff,
|
public BaseReduceOp(SameDiff sameDiff,
|
||||||
SDVariable i_v,
|
SDVariable i_v,
|
||||||
int[] dimensions, boolean keepDims) {
|
int[] dimensions, boolean keepDims) {
|
||||||
super(sameDiff,new Object[]{dimensions});
|
super(sameDiff, null);
|
||||||
if (i_v != null) {
|
if (i_v != null) {
|
||||||
if(dimensions == null || dimensions.length < 1)
|
if(dimensions == null || dimensions.length < 1)
|
||||||
dimensions = new int[] {Integer.MAX_VALUE};
|
dimensions = new int[] {Integer.MAX_VALUE};
|
||||||
|
@ -86,7 +86,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
|
||||||
SDVariable i_v,
|
SDVariable i_v,
|
||||||
SDVariable i_v2,
|
SDVariable i_v2,
|
||||||
int[] dimensions, boolean keepDims) {
|
int[] dimensions, boolean keepDims) {
|
||||||
super(sameDiff,new Object[]{dimensions});
|
super(sameDiff,null);
|
||||||
if (i_v != null) {
|
if (i_v != null) {
|
||||||
if(dimensions == null || dimensions.length < 1)
|
if(dimensions == null || dimensions.length < 1)
|
||||||
dimensions = new int[] {Integer.MAX_VALUE};
|
dimensions = new int[] {Integer.MAX_VALUE};
|
||||||
|
|
|
@ -2089,7 +2089,9 @@ public class Shape {
|
||||||
}
|
}
|
||||||
|
|
||||||
// we need to wrap buffer of a current array, to make sure it's properly marked as a View
|
// we need to wrap buffer of a current array, to make sure it's properly marked as a View
|
||||||
INDArray ret = Nd4j.create(Nd4j.createBuffer(arr.data(), arr.offset(), arr.length()), newShape, newStrides, arr.offset(), isFOrder ? 'f' : 'c');
|
DataBuffer db = arr.data();
|
||||||
|
DataBuffer buffer = Nd4j.createBuffer(db, arr.offset(), arr.length());
|
||||||
|
INDArray ret = Nd4j.create(buffer, newShape, newStrides, arr.offset(), isFOrder ? 'f' : 'c');
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -305,7 +305,7 @@ public class NDArrayStrings {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (i < l - 1) {
|
if (i < l - 1) {
|
||||||
if (!summarize || i < 2 || i > l - 3 || (summarize && l == 6)) {
|
if (!summarize || i <= 2 || i >= l - 3 || (summarize && l == 6)) {
|
||||||
sb.append(colSep);
|
sb.append(colSep);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.serde.jackson.shaded;
|
package org.nd4j.serde.jackson.shaded;
|
||||||
|
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.buffer.Utf8Buffer;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.shape.Shape;
|
import org.nd4j.linalg.api.shape.Shape;
|
||||||
import org.nd4j.serde.base64.Nd4jBase64;
|
import org.nd4j.serde.base64.Nd4jBase64;
|
||||||
|
@ -76,9 +77,12 @@ public class NDArrayTextSerializer extends JsonSerializer<INDArray> {
|
||||||
jg.writeNumber(v);
|
jg.writeNumber(v);
|
||||||
break;
|
break;
|
||||||
case UTF8:
|
case UTF8:
|
||||||
String[] str = new String[(int)arr.length()];
|
Utf8Buffer utf8B = ((Utf8Buffer)arr.data());
|
||||||
for( int j=0; j<str.length; j++ )
|
long n = utf8B.getNumWords();
|
||||||
jg.writeString(arr.getString(j));
|
for( int j=0; j<n; j++ ) {
|
||||||
|
String s = utf8B.getString(j);
|
||||||
|
jg.writeString(s);
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case COMPRESSED:
|
case COMPRESSED:
|
||||||
case UNKNOWN:
|
case UNKNOWN:
|
||||||
|
|
|
@ -68,7 +68,7 @@ public class TestOpMapping extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
String opName = df.opName();
|
String opName = df.opName();
|
||||||
|
|
||||||
assertTrue(opName, opNameMapping.containsKey(opName));
|
assertTrue("Op is missing - not defined in ImportClassMapping: " + opName, opNameMapping.containsKey(opName));
|
||||||
|
|
||||||
try{
|
try{
|
||||||
String[] tfNames = df.tensorflowNames();
|
String[] tfNames = df.tensorflowNames();
|
||||||
|
|
|
@ -44,7 +44,7 @@ public class ToStringTest extends BaseNd4jTest {
|
||||||
assertEquals("[ 1, 2, 3]",
|
assertEquals("[ 1, 2, 3]",
|
||||||
Nd4j.createFromArray(1, 2, 3).toString());
|
Nd4j.createFromArray(1, 2, 3).toString());
|
||||||
|
|
||||||
assertEquals("[ 1, 2, 3 ... 6 7, 8]",
|
assertEquals("[ 1, 2, 3, 4, 5, 6, 7, 8]",
|
||||||
Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8).toString(1000, false, 2));
|
Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8).toString(1000, false, 2));
|
||||||
|
|
||||||
assertEquals("[ 1.132, 2.644, 3.234]",
|
assertEquals("[ 1.132, 2.644, 3.234]",
|
||||||
|
@ -53,9 +53,8 @@ public class ToStringTest extends BaseNd4jTest {
|
||||||
assertEquals("[ 1.132414, 2.64356456, 3.25345234]",
|
assertEquals("[ 1.132414, 2.64356456, 3.25345234]",
|
||||||
Nd4j.createFromArray(1.132414, 2.64356456, 3.25345234).toStringFull());
|
Nd4j.createFromArray(1.132414, 2.64356456, 3.25345234).toStringFull());
|
||||||
|
|
||||||
assertEquals("[ 1, 2, 3 ... 6 7, 8]",
|
assertEquals("[ 1, 2, 3, ... 6, 7, 8]",
|
||||||
Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8).toString(100, true, 1));
|
Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8).toString(6, true, 1));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -66,7 +66,7 @@ public class JsonSerdeTests extends BaseNd4jTest {
|
||||||
|
|
||||||
INDArray arr;
|
INDArray arr;
|
||||||
if(dt == DataType.UTF8){
|
if(dt == DataType.UTF8){
|
||||||
arr = Nd4j.create("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l").reshape('c', 3, 4);
|
arr = Nd4j.create("aaaaa", "bbbb", "ccc", "dd", "e", "f", "g", "h", "i", "j", "k", "l").reshape('c', 3, 4);
|
||||||
} else {
|
} else {
|
||||||
arr = in.castTo(dt);
|
arr = in.castTo(dt);
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.linalg.api.buffer;
|
package org.nd4j.linalg.api.buffer;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.bytedeco.javacpp.BytePointer;
|
import org.bytedeco.javacpp.BytePointer;
|
||||||
|
@ -42,6 +43,7 @@ public class Utf8Buffer extends BaseDataBuffer {
|
||||||
|
|
||||||
protected Collection<Pointer> references = new ArrayList<>();
|
protected Collection<Pointer> references = new ArrayList<>();
|
||||||
|
|
||||||
|
@Getter
|
||||||
protected long numWords = 0;
|
protected long numWords = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -121,6 +123,7 @@ public class Utf8Buffer extends BaseDataBuffer {
|
||||||
|
|
||||||
public Utf8Buffer(DataBuffer underlyingBuffer, long length, long offset) {
|
public Utf8Buffer(DataBuffer underlyingBuffer, long length, long offset) {
|
||||||
super(underlyingBuffer, length, offset);
|
super(underlyingBuffer, length, offset);
|
||||||
|
this.numWords = length;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Utf8Buffer(@NonNull Collection<String> strings) {
|
public Utf8Buffer(@NonNull Collection<String> strings) {
|
||||||
|
|
|
@ -87,6 +87,8 @@ public class DefaultDataBufferFactory implements DataBufferFactory {
|
||||||
return new BFloat16Buffer(underlyingBuffer, length, offset);
|
return new BFloat16Buffer(underlyingBuffer, length, offset);
|
||||||
} else if (underlyingBuffer.dataType() == DataType.HALF) {
|
} else if (underlyingBuffer.dataType() == DataType.HALF) {
|
||||||
return new HalfBuffer(underlyingBuffer, length, offset);
|
return new HalfBuffer(underlyingBuffer, length, offset);
|
||||||
|
} else if (underlyingBuffer.dataType() == DataType.UTF8) {
|
||||||
|
return new Utf8Buffer(underlyingBuffer, length, offset);
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue