More test fixes

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2022-10-07 12:28:58 +02:00
parent acdd9c0a8a
commit b8a21bc991
50 changed files with 101 additions and 49 deletions

View File

@ -21,6 +21,7 @@
package org.datavec.api.transform.ndarray;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.MathOp;
import org.datavec.api.transform.metadata.ColumnMetaData;
@ -36,6 +37,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Arrays;
@Data
@EqualsAndHashCode(callSuper = false)
public class NDArrayColumnsMathOpTransform extends BaseColumnsMathOpTransform {
public NDArrayColumnsMathOpTransform(@JsonProperty("newColumnName") String newColumnName,

View File

@ -21,6 +21,7 @@
package org.datavec.api.transform.ndarray;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.MathFunction;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.transform.BaseColumnTransform;
@ -32,6 +33,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import com.fasterxml.jackson.annotation.JsonProperty;
@Data
@EqualsAndHashCode(callSuper = false)
public class NDArrayMathFunctionTransform extends BaseColumnTransform {
//Can't guarantee that the writable won't be re-used, for example in different Spark ops on the same RDD

View File

@ -21,6 +21,7 @@
package org.datavec.api.transform.ndarray;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.MathOp;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
@ -33,6 +34,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import com.fasterxml.jackson.annotation.JsonProperty;
@Data
@EqualsAndHashCode(callSuper = false)
public class NDArrayScalarOpTransform extends BaseColumnTransform {
private final MathOp mathOp;

View File

@ -21,6 +21,7 @@
package org.datavec.api.transform.transform.string;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import com.fasterxml.jackson.annotation.JsonProperty;
@ -31,6 +32,7 @@ import java.util.Collections;
import java.util.List;
@Data
@EqualsAndHashCode(callSuper = false)
public class StringListToIndicesNDArrayTransform extends StringListToCountsNDArrayTransform {
/**
* @param columnName The name of the column to convert

View File

@ -21,6 +21,7 @@
package org.datavec.image.transform;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.ImageWritable;
import org.nd4j.linalg.factory.Nd4j;
@ -32,6 +33,7 @@ import org.bytedeco.opencv.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*;
@Data
@EqualsAndHashCode(callSuper = false)
public class LargestBlobCropTransform extends BaseImageTransform<Mat> {
protected org.nd4j.linalg.api.rng.Random rng;

View File

@ -21,6 +21,7 @@
package org.datavec.image.transform;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NonNull;
import org.datavec.image.data.ImageWritable;
@ -32,6 +33,7 @@ import java.util.*;
import org.bytedeco.opencv.opencv_core.*;
@Data
@EqualsAndHashCode(callSuper = false)
public class PipelineImageTransform extends BaseImageTransform<Mat> {
protected List<Pair<ImageTransform, Double>> imageTransforms;

View File

@ -21,6 +21,7 @@
package org.datavec.image.transform;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.ImageWritable;
import org.nd4j.linalg.factory.Nd4j;
@ -35,6 +36,7 @@ import org.bytedeco.opencv.opencv_core.*;
@JsonIgnoreProperties({"rng", "converter"})
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@EqualsAndHashCode(callSuper = false)
public class RandomCropTransform extends BaseImageTransform<Mat> {
protected int outputHeight;

View File

@ -20,11 +20,8 @@
package org.nd4j.autodiff.functions;
import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;

View File

@ -46,7 +46,9 @@ public class SDVariable implements Serializable {
protected SameDiff sameDiff;
@Getter
@Setter
protected String varName;
@Getter
@Setter
protected VariableType variableType;
@ -83,18 +85,6 @@ public class SDVariable implements Serializable {
return varName;
}
public void setVarName(String varName) {
this.varName = varName;
}
/**
* @deprecated Use {@link #name()}
*/
@Deprecated
public String getVarName(){
return name();
}
/**
* Returns true if this variable is a place holder
* @return

View File

@ -39,5 +39,6 @@ public class Variable {
protected String outputOfOp; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of
protected List<String> controlDeps; //Control dependencies: name of ops that must be available before this variable is considered available for execution
protected SDVariable gradient; //Variable corresponding to the gradient of this variable
@Builder.Default
protected int variableIndex = -1;
}

View File

@ -76,7 +76,7 @@ public class SameDiffUtils {
public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, Map<String, INDArray> externalGradients, SDVariable... inputs) {
Preconditions.checkArgument(inputs != null && inputs.length > 0, "Require at least one SDVariable to" +
" be specified when using external errors: got %s", inputs);
" be specified when using external errors: got %s", (Object) inputs);
ExternalErrorsFunction fn = new ExternalErrorsFunction(sameDiff, Arrays.asList(inputs), externalGradients);
fn.outputVariable();
return fn;

View File

@ -49,7 +49,7 @@ import java.io.Serializable;
import java.util.List;
@Getter
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration> {
public static final int DEFAULT_RELIABILITY_DIAG_NUM_BINS = 10;

View File

@ -22,8 +22,10 @@ package org.nd4j.evaluation.curves;
import lombok.Data;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.EqualsAndHashCode;
@Data
@EqualsAndHashCode(callSuper = false)
public class Histogram extends BaseHistogram {
private final String title;
private final double lower;

View File

@ -21,6 +21,8 @@
package org.nd4j.linalg.api.memory.deallocation;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.memory.Deallocatable;
import org.nd4j.linalg.api.memory.Deallocator;
@ -28,6 +30,7 @@ import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
@Data
@EqualsAndHashCode(callSuper = false)
public class DeallocatableReference extends WeakReference<Deallocatable> {
private String id;
private Deallocator deallocator;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -36,6 +37,7 @@ import java.util.List;
@Slf4j
@Data
@EqualsAndHashCode(callSuper = false)
public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccumulation {
protected boolean keepDims = false;

View File

@ -44,6 +44,9 @@ import java.lang.reflect.Array;
import java.util.*;
@Slf4j
@Builder
@AllArgsConstructor
@EqualsAndHashCode(callSuper = true)
public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
private String opName;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.custom;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable;
@ -35,6 +36,7 @@ import java.util.List;
@Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Flatten extends DynamicCustomOp {
private int order;

View File

@ -21,6 +21,8 @@
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
import java.util.List;
import lombok.EqualsAndHashCode;
import lombok.NonNull;
import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable;
@ -39,6 +41,7 @@ import org.tensorflow.framework.NodeDef;
import java.util.HashMap;
import java.util.Map;
@EqualsAndHashCode(callSuper = false)
public abstract class BaseCompatOp extends DynamicCustomOp {
protected String frameName;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -37,6 +38,7 @@ import java.util.Map;
@Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Enter extends BaseCompatOp {
protected boolean isConstant;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -36,6 +37,7 @@ import java.util.Map;
@Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class While extends BaseCompatOp {
protected boolean isConstant;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
@ -36,6 +37,7 @@ import java.util.List;
@Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class FirstIndex extends BaseIndexAccumulation {
protected Condition condition;
protected double compare;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
@ -38,6 +39,7 @@ import java.util.Map;
@Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class LastIndex extends BaseIndexAccumulation {
protected Condition condition;
protected double compare;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
@ -38,6 +39,7 @@ import java.util.List;
import java.util.Map;
@Data
@EqualsAndHashCode(callSuper = false)
public class ArgAmax extends DynamicCustomOp {
protected boolean keepDims = false;
private int[] dimensions;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
@ -38,6 +39,7 @@ import java.util.List;
import java.util.Map;
@Data
@EqualsAndHashCode(callSuper = false)
public class ArgAmin extends DynamicCustomOp {
protected boolean keepDims = false;
private int[] dimensions;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
@ -37,6 +38,7 @@ import java.util.List;
import java.util.Map;
@Data
@EqualsAndHashCode(callSuper = false)
public class ArgMax extends DynamicCustomOp {
protected boolean keepDims = false;
private int[] dimensions;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
@ -37,6 +38,7 @@ import java.util.List;
import java.util.Map;
@Data
@EqualsAndHashCode(callSuper = false)
public class ArgMin extends DynamicCustomOp {
protected boolean keepDims = false;
private int[] dimensions;

View File

@ -22,16 +22,15 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.*;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil;
@Data
@Builder
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Conv1DConfig extends BaseConvolutionConfig {
public static final String NCW = "NCW";
public static final String NWC = "NWC";

View File

@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.common.base.Preconditions;
import org.nd4j.enums.WeightsFormat;
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data
@Builder
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Conv2DConfig extends BaseConvolutionConfig {
public static final String NCHW = "NCHW";
public static final String NHWC = "NHWC";

View File

@ -25,6 +25,7 @@ import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil;
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data
@Builder
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Conv3DConfig extends BaseConvolutionConfig {
public static final String NDHWC = "NDHWC";
public static final String NCDHW = "NCDHW";

View File

@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil;
@ -31,6 +32,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data
@Builder
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class DeConv2DConfig extends BaseConvolutionConfig {
public static final String NCHW = "NCHW";
public static final String NHWC = "NHWC";

View File

@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil;
@ -31,6 +32,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data
@Builder
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class DeConv3DConfig extends BaseConvolutionConfig {
public static final String NCDHW = "NCDHW";
public static final String NDHWC = "NDHWC";

View File

@ -24,12 +24,14 @@ import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.util.ConvConfigUtil;
@Data
@Builder
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class LocalResponseNormalizationConfig extends BaseConvolutionConfig {
private double alpha, beta, bias;

View File

@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.Divisor;
@ -33,6 +34,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data
@Builder
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Pooling2DConfig extends BaseConvolutionConfig {
@Builder.Default private long kH = -1, kW = -1;

View File

@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D.Pooling3DType;
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data
@Builder
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Pooling3DConfig extends BaseConvolutionConfig {
@Builder.Default private long kD = -1, kW = -1, kH = -1; // kernel
@Builder.Default private long sD = 1, sW = 1, sH = 1; // strides

View File

@ -39,7 +39,7 @@ import org.tensorflow.framework.NodeDef;
import java.lang.reflect.Field;
import java.util.*;
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
public class Mmul extends DynamicCustomOp {
protected MMulTranspose mt;

View File

@ -32,7 +32,7 @@ import org.nd4j.common.util.ArrayUtil;
import java.util.List;
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
public class MmulBp extends DynamicCustomOp {
protected MMulTranspose mt;

View File

@ -32,7 +32,7 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
public class BatchMmul extends DynamicCustomOp {
protected int transposeA;

View File

@ -39,10 +39,15 @@ import java.util.Map;
public class BalanceMinibatches {
private DataSetIterator dataSetIterator;
private int numLabels;
@Builder.Default
private Map<Integer, List<File>> paths = Maps.newHashMap();
@Builder.Default
private int miniBatchSize = -1;
@Builder.Default
private File rootDir = new File("minibatches");
@Builder.Default
private File rootSaveDir = new File("minibatchessave");
@Builder.Default
private List<File> labelRootDirs = new ArrayList<>();
private DataNormalization dataNormalization;

View File

@ -20,7 +20,8 @@
package org.nd4j.linalg.learning.config;
import lombok.*;
import lombok.Builder;
import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.AdaMaxUpdater;
import org.nd4j.linalg.learning.GradientUpdater;
@ -44,7 +45,8 @@ public class AdaMax implements IUpdater {
public static final double DEFAULT_ADAMAX_BETA1_MEAN_DECAY = 0.9;
public static final double DEFAULT_ADAMAX_BETA2_VAR_DECAY = 0.999;
@lombok.Builder.Default private double learningRate = DEFAULT_ADAMAX_LEARNING_RATE; // learning rate
@lombok.Builder.Default
private double learningRate = DEFAULT_ADAMAX_LEARNING_RATE; // learning rate
private ISchedule learningRateSchedule;
@lombok.Builder.Default private double beta1 = DEFAULT_ADAMAX_BETA1_MEAN_DECAY; // gradient moving avg decay rate
@lombok.Builder.Default private double beta2 = DEFAULT_ADAMAX_BETA2_VAR_DECAY; // gradient sqrd decay rate

View File

@ -335,20 +335,6 @@ public class OpProfiler {
}
}
/**
* Dev-time method.
*
* @return
*/
protected StackAggregator getMixedOrderAggregator() {
// FIXME: remove this method, or make it protected
return mixedOrderAggregator;
}
public StackAggregator getScalarAggregator() {
return scalarAggregator;
}
protected void updatePairs(String opName, String opClass) {
// now we save pairs of ops/classes
String cOpNameKey = prevOpName + " -> " + opName;

View File

@ -25,4 +25,5 @@ dependencies {
implementation "org.slf4j:slf4j-api"
implementation "org.apache.commons:commons-lang3"
implementation "com.fasterxml.jackson.core:jackson-annotations"
}

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.modelimport.keras.layers;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.common.config.DL4JClassLoading;
@ -46,6 +47,7 @@ import java.util.List;
@Slf4j
@Data
@EqualsAndHashCode(callSuper = false)
public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
@ -39,6 +40,7 @@ import java.util.Map;
*/
@Slf4j
@Data
@EqualsAndHashCode(callSuper = false)
public class KerasMasking extends KerasLayer {
private double maskingValue;

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
@ -35,6 +36,7 @@ import java.util.Map;
@Slf4j
@Data
@EqualsAndHashCode(callSuper = false)
public class KerasMerge extends KerasLayer {
private final String LAYER_FIELD_MODE = "mode";

View File

@ -55,11 +55,13 @@ public class VPTree implements Serializable {
private Node root;
private String similarityFunction;
@Getter
@Builder.Default
private boolean invert = false;
private transient ExecutorService executorService;
@Getter
@Builder.Default
private int workers = 1;
private AtomicInteger size = new AtomicInteger(0);
@Builder.Default private AtomicInteger size = new AtomicInteger(0);
private transient ThreadLocal<INDArray> scalars = new ThreadLocal<>();

View File

@ -21,11 +21,13 @@
package org.deeplearning4j.eval.curves;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.evaluation.curves.BaseHistogram;
import com.fasterxml.jackson.annotation.JsonProperty;
@Deprecated
@Data
@EqualsAndHashCode(callSuper = false)
public class Histogram extends org.nd4j.evaluation.curves.Histogram {
/**

View File

@ -76,8 +76,9 @@ import static com.google.common.base.Preconditions.checkArgument;
@JsonIgnoreProperties({"stats", "listeners", "iterationCount", "rng", "lastExportedRDDId", "lastRDDExportPath",
"trainingMasterUID"})
@EqualsAndHashCode(exclude = {"stats", "listeners", "iterationCount", "rng", "lastExportedRDDId", "lastRDDExportPath",
"trainingMasterUID"})
"trainingMasterUID"}, callSuper = false)
@Slf4j
public class ParameterAveragingTrainingMaster
extends BaseTrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker>
implements TrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker> {

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.spark.parameterserver.training;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
@ -100,7 +101,7 @@ import java.util.concurrent.atomic.AtomicInteger;
*/
@Slf4j
@Data
@EqualsAndHashCode(callSuper = false)
public class SharedTrainingMaster extends BaseTrainingMaster<SharedTrainingResult, SharedTrainingWorker>
implements TrainingMaster<SharedTrainingResult, SharedTrainingWorker> {
//Static counter/id fields used to determine which training master last set up the singleton param servers, etc

View File

@ -53,16 +53,16 @@ public class ParameterServerClient implements NDArrayCallback {
//port to listen on for the subscriber
private int subscriberPort;
//the stream to listen on for the subscriber
private int subscriberStream = 11;
@Builder.Default private int subscriberStream = 11;
//the "current" ndarray
private AtomicReference<INDArray> arr;
private INDArray none = Nd4j.scalar(1.0);
@Builder.Default private INDArray none = Nd4j.scalar(1.0);
private AtomicBoolean running;
private String masterStatusHost;
private int masterStatusPort;
private ObjectMapper objectMapper = new ObjectMapper();
@Builder.Default private ObjectMapper objectMapper = new ObjectMapper();
private Aeron aeron;
private boolean compressArray = true;
@Builder.Default private boolean compressArray = true;
/**
* Tracks number of

View File

@ -47,7 +47,7 @@ public class TextGenerationLSTM extends ZooModel {
@Builder.Default private long seed = 1234;
@Builder.Default private int maxLength = 40;
@Builder.Default private int totalUniqueCharacters = 47;
private int[] inputShape = new int[] {maxLength, totalUniqueCharacters};
@Builder.Default private int[] inputShape = new int[] {maxLength, totalUniqueCharacters};
@Builder.Default private IUpdater updater = new RmsProp(0.01);
@Builder.Default private CacheMode cacheMode = CacheMode.NONE;
@Builder.Default private WorkspaceMode workspaceMode = WorkspaceMode.ENABLED;