More test fixes

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2022-10-07 12:28:58 +02:00
parent acdd9c0a8a
commit b8a21bc991
50 changed files with 101 additions and 49 deletions

View File

@ -21,6 +21,7 @@
package org.datavec.api.transform.ndarray; package org.datavec.api.transform.ndarray;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.MathOp; import org.datavec.api.transform.MathOp;
import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.ColumnMetaData;
@ -36,6 +37,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Arrays; import java.util.Arrays;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class NDArrayColumnsMathOpTransform extends BaseColumnsMathOpTransform { public class NDArrayColumnsMathOpTransform extends BaseColumnsMathOpTransform {
public NDArrayColumnsMathOpTransform(@JsonProperty("newColumnName") String newColumnName, public NDArrayColumnsMathOpTransform(@JsonProperty("newColumnName") String newColumnName,

View File

@ -21,6 +21,7 @@
package org.datavec.api.transform.ndarray; package org.datavec.api.transform.ndarray;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.MathFunction; import org.datavec.api.transform.MathFunction;
import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.transform.BaseColumnTransform; import org.datavec.api.transform.transform.BaseColumnTransform;
@ -32,6 +33,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class NDArrayMathFunctionTransform extends BaseColumnTransform { public class NDArrayMathFunctionTransform extends BaseColumnTransform {
//Can't guarantee that the writable won't be re-used, for example in different Spark ops on the same RDD //Can't guarantee that the writable won't be re-used, for example in different Spark ops on the same RDD

View File

@ -21,6 +21,7 @@
package org.datavec.api.transform.ndarray; package org.datavec.api.transform.ndarray;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.MathOp; import org.datavec.api.transform.MathOp;
import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData; import org.datavec.api.transform.metadata.NDArrayMetaData;
@ -33,6 +34,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class NDArrayScalarOpTransform extends BaseColumnTransform { public class NDArrayScalarOpTransform extends BaseColumnTransform {
private final MathOp mathOp; private final MathOp mathOp;

View File

@ -21,6 +21,7 @@
package org.datavec.api.transform.transform.string; package org.datavec.api.transform.transform.string;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
@ -31,6 +32,7 @@ import java.util.Collections;
import java.util.List; import java.util.List;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class StringListToIndicesNDArrayTransform extends StringListToCountsNDArrayTransform { public class StringListToIndicesNDArrayTransform extends StringListToCountsNDArrayTransform {
/** /**
* @param columnName The name of the column to convert * @param columnName The name of the column to convert

View File

@ -21,6 +21,7 @@
package org.datavec.image.transform; package org.datavec.image.transform;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.bytedeco.javacv.OpenCVFrameConverter; import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -32,6 +33,7 @@ import org.bytedeco.opencv.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*; import static org.bytedeco.opencv.global.opencv_imgproc.*;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class LargestBlobCropTransform extends BaseImageTransform<Mat> { public class LargestBlobCropTransform extends BaseImageTransform<Mat> {
protected org.nd4j.linalg.api.rng.Random rng; protected org.nd4j.linalg.api.rng.Random rng;

View File

@ -21,6 +21,7 @@
package org.datavec.image.transform; package org.datavec.image.transform;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NonNull; import lombok.NonNull;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
@ -32,6 +33,7 @@ import java.util.*;
import org.bytedeco.opencv.opencv_core.*; import org.bytedeco.opencv.opencv_core.*;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class PipelineImageTransform extends BaseImageTransform<Mat> { public class PipelineImageTransform extends BaseImageTransform<Mat> {
protected List<Pair<ImageTransform, Double>> imageTransforms; protected List<Pair<ImageTransform, Double>> imageTransforms;

View File

@ -21,6 +21,7 @@
package org.datavec.image.transform; package org.datavec.image.transform;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.bytedeco.javacv.OpenCVFrameConverter; import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -35,6 +36,7 @@ import org.bytedeco.opencv.opencv_core.*;
@JsonIgnoreProperties({"rng", "converter"}) @JsonIgnoreProperties({"rng", "converter"})
@JsonInclude(JsonInclude.Include.NON_NULL) @JsonInclude(JsonInclude.Include.NON_NULL)
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class RandomCropTransform extends BaseImageTransform<Mat> { public class RandomCropTransform extends BaseImageTransform<Mat> {
protected int outputHeight; protected int outputHeight;

View File

@ -20,11 +20,8 @@
package org.nd4j.autodiff.functions; package org.nd4j.autodiff.functions;
import lombok.Data; import lombok.*;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.Onnx; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;

View File

@ -46,7 +46,9 @@ public class SDVariable implements Serializable {
protected SameDiff sameDiff; protected SameDiff sameDiff;
@Getter @Getter
@Setter
protected String varName; protected String varName;
@Getter @Getter
@Setter @Setter
protected VariableType variableType; protected VariableType variableType;
@ -83,18 +85,6 @@ public class SDVariable implements Serializable {
return varName; return varName;
} }
public void setVarName(String varName) {
this.varName = varName;
}
/**
* @deprecated Use {@link #name()}
*/
@Deprecated
public String getVarName(){
return name();
}
/** /**
* Returns true if this variable is a place holder * Returns true if this variable is a place holder
* @return * @return

View File

@ -39,5 +39,6 @@ public class Variable {
protected String outputOfOp; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of protected String outputOfOp; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of
protected List<String> controlDeps; //Control dependencies: name of ops that must be available before this variable is considered available for execution protected List<String> controlDeps; //Control dependencies: name of ops that must be available before this variable is considered available for execution
protected SDVariable gradient; //Variable corresponding to the gradient of this variable protected SDVariable gradient; //Variable corresponding to the gradient of this variable
@Builder.Default
protected int variableIndex = -1; protected int variableIndex = -1;
} }

View File

@ -76,7 +76,7 @@ public class SameDiffUtils {
public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, Map<String, INDArray> externalGradients, SDVariable... inputs) { public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, Map<String, INDArray> externalGradients, SDVariable... inputs) {
Preconditions.checkArgument(inputs != null && inputs.length > 0, "Require at least one SDVariable to" + Preconditions.checkArgument(inputs != null && inputs.length > 0, "Require at least one SDVariable to" +
" be specified when using external errors: got %s", inputs); " be specified when using external errors: got %s", (Object) inputs);
ExternalErrorsFunction fn = new ExternalErrorsFunction(sameDiff, Arrays.asList(inputs), externalGradients); ExternalErrorsFunction fn = new ExternalErrorsFunction(sameDiff, Arrays.asList(inputs), externalGradients);
fn.outputVariable(); fn.outputVariable();
return fn; return fn;

View File

@ -49,7 +49,7 @@ import java.io.Serializable;
import java.util.List; import java.util.List;
@Getter @Getter
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration> { public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration> {
public static final int DEFAULT_RELIABILITY_DIAG_NUM_BINS = 10; public static final int DEFAULT_RELIABILITY_DIAG_NUM_BINS = 10;

View File

@ -22,8 +22,10 @@ package org.nd4j.evaluation.curves;
import lombok.Data; import lombok.Data;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.EqualsAndHashCode;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class Histogram extends BaseHistogram { public class Histogram extends BaseHistogram {
private final String title; private final String title;
private final double lower; private final double lower;

View File

@ -21,6 +21,8 @@
package org.nd4j.linalg.api.memory.deallocation; package org.nd4j.linalg.api.memory.deallocation;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.memory.Deallocatable; import org.nd4j.linalg.api.memory.Deallocatable;
import org.nd4j.linalg.api.memory.Deallocator; import org.nd4j.linalg.api.memory.Deallocator;
@ -28,6 +30,7 @@ import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference; import java.lang.ref.WeakReference;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class DeallocatableReference extends WeakReference<Deallocatable> { public class DeallocatableReference extends WeakReference<Deallocatable> {
private String id; private String id;
private Deallocator deallocator; private Deallocator deallocator;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops; package org.nd4j.linalg.api.ops;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -36,6 +37,7 @@ import java.util.List;
@Slf4j @Slf4j
@Data @Data
@EqualsAndHashCode(callSuper = false)
public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccumulation { public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccumulation {
protected boolean keepDims = false; protected boolean keepDims = false;

View File

@ -44,6 +44,9 @@ import java.lang.reflect.Array;
import java.util.*; import java.util.*;
@Slf4j @Slf4j
@Builder
@AllArgsConstructor
@EqualsAndHashCode(callSuper = true)
public class DynamicCustomOp extends DifferentialFunction implements CustomOp { public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
private String opName; private String opName;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.custom; package org.nd4j.linalg.api.ops.custom;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.val; import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -35,6 +36,7 @@ import java.util.List;
@Data @Data
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Flatten extends DynamicCustomOp { public class Flatten extends DynamicCustomOp {
private int order; private int order;

View File

@ -21,6 +21,8 @@
package org.nd4j.linalg.api.ops.impl.controlflow.compat; package org.nd4j.linalg.api.ops.impl.controlflow.compat;
import java.util.List; import java.util.List;
import lombok.EqualsAndHashCode;
import lombok.NonNull; import lombok.NonNull;
import lombok.val; import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -39,6 +41,7 @@ import org.tensorflow.framework.NodeDef;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@EqualsAndHashCode(callSuper = false)
public abstract class BaseCompatOp extends DynamicCustomOp { public abstract class BaseCompatOp extends DynamicCustomOp {
protected String frameName; protected String frameName;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.controlflow.compat; package org.nd4j.linalg.api.ops.impl.controlflow.compat;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -37,6 +38,7 @@ import java.util.Map;
@Data @Data
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Enter extends BaseCompatOp { public class Enter extends BaseCompatOp {
protected boolean isConstant; protected boolean isConstant;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.controlflow.compat; package org.nd4j.linalg.api.ops.impl.controlflow.compat;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -36,6 +37,7 @@ import java.util.Map;
@Data @Data
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class While extends BaseCompatOp { public class While extends BaseCompatOp {
protected boolean isConstant; protected boolean isConstant;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum; package org.nd4j.linalg.api.ops.impl.indexaccum;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.NonNull; import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -36,6 +37,7 @@ import java.util.List;
@Data @Data
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class FirstIndex extends BaseIndexAccumulation { public class FirstIndex extends BaseIndexAccumulation {
protected Condition condition; protected Condition condition;
protected double compare; protected double compare;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum; package org.nd4j.linalg.api.ops.impl.indexaccum;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.NonNull; import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -38,6 +39,7 @@ import java.util.Map;
@Data @Data
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class LastIndex extends BaseIndexAccumulation { public class LastIndex extends BaseIndexAccumulation {
protected Condition condition; protected Condition condition;
protected double compare; protected double compare;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum.custom; package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
@ -38,6 +39,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class ArgAmax extends DynamicCustomOp { public class ArgAmax extends DynamicCustomOp {
protected boolean keepDims = false; protected boolean keepDims = false;
private int[] dimensions; private int[] dimensions;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum.custom; package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
@ -38,6 +39,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class ArgAmin extends DynamicCustomOp { public class ArgAmin extends DynamicCustomOp {
protected boolean keepDims = false; protected boolean keepDims = false;
private int[] dimensions; private int[] dimensions;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum.custom; package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
@ -37,6 +38,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class ArgMax extends DynamicCustomOp { public class ArgMax extends DynamicCustomOp {
protected boolean keepDims = false; protected boolean keepDims = false;
private int[] dimensions; private int[] dimensions;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum.custom; package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
@ -37,6 +38,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class ArgMin extends DynamicCustomOp { public class ArgMin extends DynamicCustomOp {
protected boolean keepDims = false; protected boolean keepDims = false;
private int[] dimensions; private int[] dimensions;

View File

@ -22,16 +22,15 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import lombok.Builder;
import lombok.Data; import lombok.*;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil; import org.nd4j.linalg.util.ConvConfigUtil;
@Data @Data
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Conv1DConfig extends BaseConvolutionConfig { public class Conv1DConfig extends BaseConvolutionConfig {
public static final String NCW = "NCW"; public static final String NCW = "NCW";
public static final String NWC = "NWC"; public static final String NWC = "NWC";

View File

@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
import org.nd4j.enums.WeightsFormat; import org.nd4j.enums.WeightsFormat;
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data @Data
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Conv2DConfig extends BaseConvolutionConfig { public class Conv2DConfig extends BaseConvolutionConfig {
public static final String NCHW = "NCHW"; public static final String NCHW = "NCHW";
public static final String NHWC = "NHWC"; public static final String NHWC = "NHWC";

View File

@ -25,6 +25,7 @@ import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil; import org.nd4j.linalg.util.ConvConfigUtil;
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data @Data
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Conv3DConfig extends BaseConvolutionConfig { public class Conv3DConfig extends BaseConvolutionConfig {
public static final String NDHWC = "NDHWC"; public static final String NDHWC = "NDHWC";
public static final String NCDHW = "NCDHW"; public static final String NCDHW = "NCDHW";

View File

@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil; import org.nd4j.linalg.util.ConvConfigUtil;
@ -31,6 +32,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data @Data
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class DeConv2DConfig extends BaseConvolutionConfig { public class DeConv2DConfig extends BaseConvolutionConfig {
public static final String NCHW = "NCHW"; public static final String NCHW = "NCHW";
public static final String NHWC = "NHWC"; public static final String NHWC = "NHWC";

View File

@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil; import org.nd4j.linalg.util.ConvConfigUtil;
@ -31,6 +32,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data @Data
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class DeConv3DConfig extends BaseConvolutionConfig { public class DeConv3DConfig extends BaseConvolutionConfig {
public static final String NCDHW = "NCDHW"; public static final String NCDHW = "NCDHW";
public static final String NDHWC = "NDHWC"; public static final String NDHWC = "NDHWC";

View File

@ -24,12 +24,14 @@ import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.linalg.util.ConvConfigUtil; import org.nd4j.linalg.util.ConvConfigUtil;
@Data @Data
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class LocalResponseNormalizationConfig extends BaseConvolutionConfig { public class LocalResponseNormalizationConfig extends BaseConvolutionConfig {
private double alpha, beta, bias; private double alpha, beta, bias;

View File

@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.Divisor; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.Divisor;
@ -33,6 +34,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data @Data
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Pooling2DConfig extends BaseConvolutionConfig { public class Pooling2DConfig extends BaseConvolutionConfig {
@Builder.Default private long kH = -1, kW = -1; @Builder.Default private long kH = -1, kW = -1;

View File

@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D.Pooling3DType; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D.Pooling3DType;
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data @Data
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Pooling3DConfig extends BaseConvolutionConfig { public class Pooling3DConfig extends BaseConvolutionConfig {
@Builder.Default private long kD = -1, kW = -1, kH = -1; // kernel @Builder.Default private long kD = -1, kW = -1, kH = -1; // kernel
@Builder.Default private long sD = 1, sW = 1, sH = 1; // strides @Builder.Default private long sD = 1, sW = 1, sH = 1; // strides

View File

@ -39,7 +39,7 @@ import org.tensorflow.framework.NodeDef;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.util.*; import java.util.*;
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
public class Mmul extends DynamicCustomOp { public class Mmul extends DynamicCustomOp {
protected MMulTranspose mt; protected MMulTranspose mt;

View File

@ -32,7 +32,7 @@ import org.nd4j.common.util.ArrayUtil;
import java.util.List; import java.util.List;
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
public class MmulBp extends DynamicCustomOp { public class MmulBp extends DynamicCustomOp {
protected MMulTranspose mt; protected MMulTranspose mt;

View File

@ -32,7 +32,7 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.*; import java.util.*;
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
public class BatchMmul extends DynamicCustomOp { public class BatchMmul extends DynamicCustomOp {
protected int transposeA; protected int transposeA;

View File

@ -39,10 +39,15 @@ import java.util.Map;
public class BalanceMinibatches { public class BalanceMinibatches {
private DataSetIterator dataSetIterator; private DataSetIterator dataSetIterator;
private int numLabels; private int numLabels;
@Builder.Default
private Map<Integer, List<File>> paths = Maps.newHashMap(); private Map<Integer, List<File>> paths = Maps.newHashMap();
@Builder.Default
private int miniBatchSize = -1; private int miniBatchSize = -1;
@Builder.Default
private File rootDir = new File("minibatches"); private File rootDir = new File("minibatches");
@Builder.Default
private File rootSaveDir = new File("minibatchessave"); private File rootSaveDir = new File("minibatchessave");
@Builder.Default
private List<File> labelRootDirs = new ArrayList<>(); private List<File> labelRootDirs = new ArrayList<>();
private DataNormalization dataNormalization; private DataNormalization dataNormalization;

View File

@ -20,7 +20,8 @@
package org.nd4j.linalg.learning.config; package org.nd4j.linalg.learning.config;
import lombok.*; import lombok.Builder;
import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.AdaMaxUpdater; import org.nd4j.linalg.learning.AdaMaxUpdater;
import org.nd4j.linalg.learning.GradientUpdater; import org.nd4j.linalg.learning.GradientUpdater;
@ -44,7 +45,8 @@ public class AdaMax implements IUpdater {
public static final double DEFAULT_ADAMAX_BETA1_MEAN_DECAY = 0.9; public static final double DEFAULT_ADAMAX_BETA1_MEAN_DECAY = 0.9;
public static final double DEFAULT_ADAMAX_BETA2_VAR_DECAY = 0.999; public static final double DEFAULT_ADAMAX_BETA2_VAR_DECAY = 0.999;
@lombok.Builder.Default private double learningRate = DEFAULT_ADAMAX_LEARNING_RATE; // learning rate @lombok.Builder.Default
private double learningRate = DEFAULT_ADAMAX_LEARNING_RATE; // learning rate
private ISchedule learningRateSchedule; private ISchedule learningRateSchedule;
@lombok.Builder.Default private double beta1 = DEFAULT_ADAMAX_BETA1_MEAN_DECAY; // gradient moving avg decay rate @lombok.Builder.Default private double beta1 = DEFAULT_ADAMAX_BETA1_MEAN_DECAY; // gradient moving avg decay rate
@lombok.Builder.Default private double beta2 = DEFAULT_ADAMAX_BETA2_VAR_DECAY; // gradient sqrd decay rate @lombok.Builder.Default private double beta2 = DEFAULT_ADAMAX_BETA2_VAR_DECAY; // gradient sqrd decay rate

View File

@ -335,20 +335,6 @@ public class OpProfiler {
} }
} }
/**
* Dev-time method.
*
* @return
*/
protected StackAggregator getMixedOrderAggregator() {
// FIXME: remove this method, or make it protected
return mixedOrderAggregator;
}
public StackAggregator getScalarAggregator() {
return scalarAggregator;
}
protected void updatePairs(String opName, String opClass) { protected void updatePairs(String opName, String opClass) {
// now we save pairs of ops/classes // now we save pairs of ops/classes
String cOpNameKey = prevOpName + " -> " + opName; String cOpNameKey = prevOpName + " -> " + opName;

View File

@ -25,4 +25,5 @@ dependencies {
implementation "org.slf4j:slf4j-api" implementation "org.slf4j:slf4j-api"
implementation "org.apache.commons:commons-lang3" implementation "org.apache.commons:commons-lang3"
implementation "com.fasterxml.jackson.core:jackson-annotations"
} }

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.modelimport.keras.layers; package org.deeplearning4j.nn.modelimport.keras.layers;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.common.config.DL4JClassLoading; import org.deeplearning4j.common.config.DL4JClassLoading;
@ -46,6 +47,7 @@ import java.util.List;
@Slf4j @Slf4j
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> { public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.modelimport.keras.layers.core; package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
@ -39,6 +40,7 @@ import java.util.Map;
*/ */
@Slf4j @Slf4j
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class KerasMasking extends KerasLayer { public class KerasMasking extends KerasLayer {
private double maskingValue; private double maskingValue;

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.modelimport.keras.layers.core; package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex; import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex; import org.deeplearning4j.nn.conf.graph.MergeVertex;
@ -35,6 +36,7 @@ import java.util.Map;
@Slf4j @Slf4j
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class KerasMerge extends KerasLayer { public class KerasMerge extends KerasLayer {
private final String LAYER_FIELD_MODE = "mode"; private final String LAYER_FIELD_MODE = "mode";

View File

@ -55,11 +55,13 @@ public class VPTree implements Serializable {
private Node root; private Node root;
private String similarityFunction; private String similarityFunction;
@Getter @Getter
@Builder.Default
private boolean invert = false; private boolean invert = false;
private transient ExecutorService executorService; private transient ExecutorService executorService;
@Getter @Getter
@Builder.Default
private int workers = 1; private int workers = 1;
private AtomicInteger size = new AtomicInteger(0); @Builder.Default private AtomicInteger size = new AtomicInteger(0);
private transient ThreadLocal<INDArray> scalars = new ThreadLocal<>(); private transient ThreadLocal<INDArray> scalars = new ThreadLocal<>();

View File

@ -21,11 +21,13 @@
package org.deeplearning4j.eval.curves; package org.deeplearning4j.eval.curves;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.evaluation.curves.BaseHistogram; import org.nd4j.evaluation.curves.BaseHistogram;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
@Deprecated @Deprecated
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class Histogram extends org.nd4j.evaluation.curves.Histogram { public class Histogram extends org.nd4j.evaluation.curves.Histogram {
/** /**

View File

@ -76,8 +76,9 @@ import static com.google.common.base.Preconditions.checkArgument;
@JsonIgnoreProperties({"stats", "listeners", "iterationCount", "rng", "lastExportedRDDId", "lastRDDExportPath", @JsonIgnoreProperties({"stats", "listeners", "iterationCount", "rng", "lastExportedRDDId", "lastRDDExportPath",
"trainingMasterUID"}) "trainingMasterUID"})
@EqualsAndHashCode(exclude = {"stats", "listeners", "iterationCount", "rng", "lastExportedRDDId", "lastRDDExportPath", @EqualsAndHashCode(exclude = {"stats", "listeners", "iterationCount", "rng", "lastExportedRDDId", "lastRDDExportPath",
"trainingMasterUID"}) "trainingMasterUID"}, callSuper = false)
@Slf4j @Slf4j
public class ParameterAveragingTrainingMaster public class ParameterAveragingTrainingMaster
extends BaseTrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker> extends BaseTrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker>
implements TrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker> { implements TrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker> {

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.spark.parameterserver.training; package org.deeplearning4j.spark.parameterserver.training;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NonNull; import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
@ -100,7 +101,7 @@ import java.util.concurrent.atomic.AtomicInteger;
*/ */
@Slf4j @Slf4j
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class SharedTrainingMaster extends BaseTrainingMaster<SharedTrainingResult, SharedTrainingWorker> public class SharedTrainingMaster extends BaseTrainingMaster<SharedTrainingResult, SharedTrainingWorker>
implements TrainingMaster<SharedTrainingResult, SharedTrainingWorker> { implements TrainingMaster<SharedTrainingResult, SharedTrainingWorker> {
//Static counter/id fields used to determine which training master last set up the singleton param servers, etc //Static counter/id fields used to determine which training master last set up the singleton param servers, etc

View File

@ -53,16 +53,16 @@ public class ParameterServerClient implements NDArrayCallback {
//port to listen on for the subscriber //port to listen on for the subscriber
private int subscriberPort; private int subscriberPort;
//the stream to listen on for the subscriber //the stream to listen on for the subscriber
private int subscriberStream = 11; @Builder.Default private int subscriberStream = 11;
//the "current" ndarray //the "current" ndarray
private AtomicReference<INDArray> arr; private AtomicReference<INDArray> arr;
private INDArray none = Nd4j.scalar(1.0); @Builder.Default private INDArray none = Nd4j.scalar(1.0);
private AtomicBoolean running; private AtomicBoolean running;
private String masterStatusHost; private String masterStatusHost;
private int masterStatusPort; private int masterStatusPort;
private ObjectMapper objectMapper = new ObjectMapper(); @Builder.Default private ObjectMapper objectMapper = new ObjectMapper();
private Aeron aeron; private Aeron aeron;
private boolean compressArray = true; @Builder.Default private boolean compressArray = true;
/** /**
* Tracks number of * Tracks number of

View File

@ -47,7 +47,7 @@ public class TextGenerationLSTM extends ZooModel {
@Builder.Default private long seed = 1234; @Builder.Default private long seed = 1234;
@Builder.Default private int maxLength = 40; @Builder.Default private int maxLength = 40;
@Builder.Default private int totalUniqueCharacters = 47; @Builder.Default private int totalUniqueCharacters = 47;
private int[] inputShape = new int[] {maxLength, totalUniqueCharacters}; @Builder.Default private int[] inputShape = new int[] {maxLength, totalUniqueCharacters};
@Builder.Default private IUpdater updater = new RmsProp(0.01); @Builder.Default private IUpdater updater = new RmsProp(0.01);
@Builder.Default private CacheMode cacheMode = CacheMode.NONE; @Builder.Default private CacheMode cacheMode = CacheMode.NONE;
@Builder.Default private WorkspaceMode workspaceMode = WorkspaceMode.ENABLED; @Builder.Default private WorkspaceMode workspaceMode = WorkspaceMode.ENABLED;