parent
acdd9c0a8a
commit
b8a21bc991
|
@ -21,6 +21,7 @@
|
|||
package org.datavec.api.transform.ndarray;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.datavec.api.transform.ColumnType;
|
||||
import org.datavec.api.transform.MathOp;
|
||||
import org.datavec.api.transform.metadata.ColumnMetaData;
|
||||
|
@ -36,6 +37,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
|
|||
import java.util.Arrays;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class NDArrayColumnsMathOpTransform extends BaseColumnsMathOpTransform {
|
||||
|
||||
public NDArrayColumnsMathOpTransform(@JsonProperty("newColumnName") String newColumnName,
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.datavec.api.transform.ndarray;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.datavec.api.transform.MathFunction;
|
||||
import org.datavec.api.transform.metadata.ColumnMetaData;
|
||||
import org.datavec.api.transform.transform.BaseColumnTransform;
|
||||
|
@ -32,6 +33,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
|
|||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class NDArrayMathFunctionTransform extends BaseColumnTransform {
|
||||
|
||||
//Can't guarantee that the writable won't be re-used, for example in different Spark ops on the same RDD
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.datavec.api.transform.ndarray;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.datavec.api.transform.MathOp;
|
||||
import org.datavec.api.transform.metadata.ColumnMetaData;
|
||||
import org.datavec.api.transform.metadata.NDArrayMetaData;
|
||||
|
@ -33,6 +34,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
|
|||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class NDArrayScalarOpTransform extends BaseColumnTransform {
|
||||
|
||||
private final MathOp mathOp;
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.datavec.api.transform.transform.string;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
|
@ -31,6 +32,7 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class StringListToIndicesNDArrayTransform extends StringListToCountsNDArrayTransform {
|
||||
/**
|
||||
* @param columnName The name of the column to convert
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.datavec.image.transform;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||
import org.datavec.image.data.ImageWritable;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -32,6 +33,7 @@ import org.bytedeco.opencv.opencv_core.*;
|
|||
import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class LargestBlobCropTransform extends BaseImageTransform<Mat> {
|
||||
|
||||
protected org.nd4j.linalg.api.rng.Random rng;
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.datavec.image.transform;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NonNull;
|
||||
|
||||
import org.datavec.image.data.ImageWritable;
|
||||
|
@ -32,6 +33,7 @@ import java.util.*;
|
|||
import org.bytedeco.opencv.opencv_core.*;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class PipelineImageTransform extends BaseImageTransform<Mat> {
|
||||
|
||||
protected List<Pair<ImageTransform, Double>> imageTransforms;
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.datavec.image.transform;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||
import org.datavec.image.data.ImageWritable;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -35,6 +36,7 @@ import org.bytedeco.opencv.opencv_core.*;
|
|||
@JsonIgnoreProperties({"rng", "converter"})
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class RandomCropTransform extends BaseImageTransform<Mat> {
|
||||
|
||||
protected int outputHeight;
|
||||
|
|
|
@ -20,11 +20,8 @@
|
|||
|
||||
package org.nd4j.autodiff.functions;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
|
|
@ -46,7 +46,9 @@ public class SDVariable implements Serializable {
|
|||
protected SameDiff sameDiff;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
protected String varName;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
protected VariableType variableType;
|
||||
|
@ -83,18 +85,6 @@ public class SDVariable implements Serializable {
|
|||
return varName;
|
||||
}
|
||||
|
||||
public void setVarName(String varName) {
|
||||
this.varName = varName;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #name()}
|
||||
*/
|
||||
@Deprecated
|
||||
public String getVarName(){
|
||||
return name();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if this variable is a place holder
|
||||
* @return
|
||||
|
|
|
@ -39,5 +39,6 @@ public class Variable {
|
|||
protected String outputOfOp; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of
|
||||
protected List<String> controlDeps; //Control dependencies: name of ops that must be available before this variable is considered available for execution
|
||||
protected SDVariable gradient; //Variable corresponding to the gradient of this variable
|
||||
@Builder.Default
|
||||
protected int variableIndex = -1;
|
||||
}
|
||||
|
|
|
@ -76,7 +76,7 @@ public class SameDiffUtils {
|
|||
|
||||
public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, Map<String, INDArray> externalGradients, SDVariable... inputs) {
|
||||
Preconditions.checkArgument(inputs != null && inputs.length > 0, "Require at least one SDVariable to" +
|
||||
" be specified when using external errors: got %s", inputs);
|
||||
" be specified when using external errors: got %s", (Object) inputs);
|
||||
ExternalErrorsFunction fn = new ExternalErrorsFunction(sameDiff, Arrays.asList(inputs), externalGradients);
|
||||
fn.outputVariable();
|
||||
return fn;
|
||||
|
|
|
@ -49,7 +49,7 @@ import java.io.Serializable;
|
|||
import java.util.List;
|
||||
|
||||
@Getter
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration> {
|
||||
|
||||
public static final int DEFAULT_RELIABILITY_DIAG_NUM_BINS = 10;
|
||||
|
|
|
@ -22,8 +22,10 @@ package org.nd4j.evaluation.curves;
|
|||
|
||||
import lombok.Data;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class Histogram extends BaseHistogram {
|
||||
private final String title;
|
||||
private final double lower;
|
||||
|
|
|
@ -21,6 +21,8 @@
|
|||
package org.nd4j.linalg.api.memory.deallocation;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.linalg.api.memory.Deallocatable;
|
||||
import org.nd4j.linalg.api.memory.Deallocator;
|
||||
|
||||
|
@ -28,6 +30,7 @@ import java.lang.ref.ReferenceQueue;
|
|||
import java.lang.ref.WeakReference;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class DeallocatableReference extends WeakReference<Deallocatable> {
|
||||
private String id;
|
||||
private Deallocator deallocator;
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.nd4j.linalg.api.ops;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -36,6 +37,7 @@ import java.util.List;
|
|||
|
||||
@Slf4j
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccumulation {
|
||||
protected boolean keepDims = false;
|
||||
|
||||
|
|
|
@ -44,6 +44,9 @@ import java.lang.reflect.Array;
|
|||
import java.util.*;
|
||||
|
||||
@Slf4j
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
||||
|
||||
private String opName;
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.val;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
|
@ -35,6 +36,7 @@ import java.util.List;
|
|||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class Flatten extends DynamicCustomOp {
|
||||
private int order;
|
||||
|
||||
|
|
|
@ -21,6 +21,8 @@
|
|||
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NonNull;
|
||||
import lombok.val;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
|
@ -39,6 +41,7 @@ import org.tensorflow.framework.NodeDef;
|
|||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public abstract class BaseCompatOp extends DynamicCustomOp {
|
||||
protected String frameName;
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -37,6 +38,7 @@ import java.util.Map;
|
|||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class Enter extends BaseCompatOp {
|
||||
|
||||
protected boolean isConstant;
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -36,6 +37,7 @@ import java.util.Map;
|
|||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class While extends BaseCompatOp {
|
||||
|
||||
protected boolean isConstant;
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
|
@ -36,6 +37,7 @@ import java.util.List;
|
|||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class FirstIndex extends BaseIndexAccumulation {
|
||||
protected Condition condition;
|
||||
protected double compare;
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
|
@ -38,6 +39,7 @@ import java.util.Map;
|
|||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class LastIndex extends BaseIndexAccumulation {
|
||||
protected Condition condition;
|
||||
protected double compare;
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
|
@ -38,6 +39,7 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class ArgAmax extends DynamicCustomOp {
|
||||
protected boolean keepDims = false;
|
||||
private int[] dimensions;
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
|
@ -38,6 +39,7 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class ArgAmin extends DynamicCustomOp {
|
||||
protected boolean keepDims = false;
|
||||
private int[] dimensions;
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
|
@ -37,6 +38,7 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class ArgMax extends DynamicCustomOp {
|
||||
protected boolean keepDims = false;
|
||||
private int[] dimensions;
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
|
@ -37,6 +38,7 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class ArgMin extends DynamicCustomOp {
|
||||
protected boolean keepDims = false;
|
||||
private int[] dimensions;
|
||||
|
|
|
@ -22,16 +22,15 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
|||
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
|
||||
import lombok.*;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class Conv1DConfig extends BaseConvolutionConfig {
|
||||
public static final String NCW = "NCW";
|
||||
public static final String NWC = "NWC";
|
||||
|
|
|
@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
|
|||
import java.util.Map;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.enums.WeightsFormat;
|
||||
|
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
|
|||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class Conv2DConfig extends BaseConvolutionConfig {
|
||||
public static final String NCHW = "NCHW";
|
||||
public static final String NHWC = "NHWC";
|
||||
|
|
|
@ -25,6 +25,7 @@ import java.util.LinkedHashMap;
|
|||
import java.util.Map;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||
|
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
|
|||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class Conv3DConfig extends BaseConvolutionConfig {
|
||||
public static final String NDHWC = "NDHWC";
|
||||
public static final String NCDHW = "NCDHW";
|
||||
|
|
|
@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
|
|||
import java.util.Map;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||
|
@ -31,6 +32,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
|
|||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class DeConv2DConfig extends BaseConvolutionConfig {
|
||||
public static final String NCHW = "NCHW";
|
||||
public static final String NHWC = "NHWC";
|
||||
|
|
|
@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
|
|||
import java.util.Map;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||
|
@ -31,6 +32,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
|
|||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class DeConv3DConfig extends BaseConvolutionConfig {
|
||||
public static final String NCDHW = "NCDHW";
|
||||
public static final String NDHWC = "NDHWC";
|
||||
|
|
|
@ -24,12 +24,14 @@ import java.util.LinkedHashMap;
|
|||
import java.util.Map;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class LocalResponseNormalizationConfig extends BaseConvolutionConfig {
|
||||
|
||||
private double alpha, beta, bias;
|
||||
|
|
|
@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
|
|||
import java.util.Map;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.Divisor;
|
||||
|
@ -33,6 +34,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
|
|||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class Pooling2DConfig extends BaseConvolutionConfig {
|
||||
|
||||
@Builder.Default private long kH = -1, kW = -1;
|
||||
|
|
|
@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
|
|||
import java.util.Map;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D.Pooling3DType;
|
||||
|
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
|
|||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class Pooling3DConfig extends BaseConvolutionConfig {
|
||||
@Builder.Default private long kD = -1, kW = -1, kH = -1; // kernel
|
||||
@Builder.Default private long sD = 1, sW = 1, sH = 1; // strides
|
||||
|
|
|
@ -39,7 +39,7 @@ import org.tensorflow.framework.NodeDef;
|
|||
import java.lang.reflect.Field;
|
||||
import java.util.*;
|
||||
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class Mmul extends DynamicCustomOp {
|
||||
|
||||
protected MMulTranspose mt;
|
||||
|
|
|
@ -32,7 +32,7 @@ import org.nd4j.common.util.ArrayUtil;
|
|||
|
||||
import java.util.List;
|
||||
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class MmulBp extends DynamicCustomOp {
|
||||
|
||||
protected MMulTranspose mt;
|
||||
|
|
|
@ -32,7 +32,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
|
||||
import java.util.*;
|
||||
|
||||
@EqualsAndHashCode
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class BatchMmul extends DynamicCustomOp {
|
||||
|
||||
protected int transposeA;
|
||||
|
|
|
@ -39,10 +39,15 @@ import java.util.Map;
|
|||
public class BalanceMinibatches {
|
||||
private DataSetIterator dataSetIterator;
|
||||
private int numLabels;
|
||||
@Builder.Default
|
||||
private Map<Integer, List<File>> paths = Maps.newHashMap();
|
||||
@Builder.Default
|
||||
private int miniBatchSize = -1;
|
||||
@Builder.Default
|
||||
private File rootDir = new File("minibatches");
|
||||
@Builder.Default
|
||||
private File rootSaveDir = new File("minibatchessave");
|
||||
@Builder.Default
|
||||
private List<File> labelRootDirs = new ArrayList<>();
|
||||
private DataNormalization dataNormalization;
|
||||
|
||||
|
|
|
@ -20,7 +20,8 @@
|
|||
|
||||
package org.nd4j.linalg.learning.config;
|
||||
|
||||
import lombok.*;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.learning.AdaMaxUpdater;
|
||||
import org.nd4j.linalg.learning.GradientUpdater;
|
||||
|
@ -44,7 +45,8 @@ public class AdaMax implements IUpdater {
|
|||
public static final double DEFAULT_ADAMAX_BETA1_MEAN_DECAY = 0.9;
|
||||
public static final double DEFAULT_ADAMAX_BETA2_VAR_DECAY = 0.999;
|
||||
|
||||
@lombok.Builder.Default private double learningRate = DEFAULT_ADAMAX_LEARNING_RATE; // learning rate
|
||||
@lombok.Builder.Default
|
||||
private double learningRate = DEFAULT_ADAMAX_LEARNING_RATE; // learning rate
|
||||
private ISchedule learningRateSchedule;
|
||||
@lombok.Builder.Default private double beta1 = DEFAULT_ADAMAX_BETA1_MEAN_DECAY; // gradient moving avg decay rate
|
||||
@lombok.Builder.Default private double beta2 = DEFAULT_ADAMAX_BETA2_VAR_DECAY; // gradient sqrd decay rate
|
||||
|
|
|
@ -335,20 +335,6 @@ public class OpProfiler {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Dev-time method.
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
protected StackAggregator getMixedOrderAggregator() {
|
||||
// FIXME: remove this method, or make it protected
|
||||
return mixedOrderAggregator;
|
||||
}
|
||||
|
||||
public StackAggregator getScalarAggregator() {
|
||||
return scalarAggregator;
|
||||
}
|
||||
|
||||
protected void updatePairs(String opName, String opClass) {
|
||||
// now we save pairs of ops/classes
|
||||
String cOpNameKey = prevOpName + " -> " + opName;
|
||||
|
|
|
@ -25,4 +25,5 @@ dependencies {
|
|||
|
||||
implementation "org.slf4j:slf4j-api"
|
||||
implementation "org.apache.commons:commons-lang3"
|
||||
implementation "com.fasterxml.jackson.core:jackson-annotations"
|
||||
}
|
|
@ -21,6 +21,7 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras.layers;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||
|
@ -46,6 +47,7 @@ import java.util.List;
|
|||
|
||||
@Slf4j
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {
|
||||
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras.layers.core;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
|
||||
|
@ -39,6 +40,7 @@ import java.util.Map;
|
|||
*/
|
||||
@Slf4j
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class KerasMasking extends KerasLayer {
|
||||
|
||||
private double maskingValue;
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.deeplearning4j.nn.modelimport.keras.layers.core;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
|
||||
import org.deeplearning4j.nn.conf.graph.MergeVertex;
|
||||
|
@ -35,6 +36,7 @@ import java.util.Map;
|
|||
|
||||
@Slf4j
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class KerasMerge extends KerasLayer {
|
||||
|
||||
private final String LAYER_FIELD_MODE = "mode";
|
||||
|
|
|
@ -55,11 +55,13 @@ public class VPTree implements Serializable {
|
|||
private Node root;
|
||||
private String similarityFunction;
|
||||
@Getter
|
||||
@Builder.Default
|
||||
private boolean invert = false;
|
||||
private transient ExecutorService executorService;
|
||||
@Getter
|
||||
@Builder.Default
|
||||
private int workers = 1;
|
||||
private AtomicInteger size = new AtomicInteger(0);
|
||||
@Builder.Default private AtomicInteger size = new AtomicInteger(0);
|
||||
|
||||
private transient ThreadLocal<INDArray> scalars = new ThreadLocal<>();
|
||||
|
||||
|
|
|
@ -21,11 +21,13 @@
|
|||
package org.deeplearning4j.eval.curves;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.nd4j.evaluation.curves.BaseHistogram;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
|
||||
@Deprecated
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class Histogram extends org.nd4j.evaluation.curves.Histogram {
|
||||
|
||||
/**
|
||||
|
|
|
@ -76,8 +76,9 @@ import static com.google.common.base.Preconditions.checkArgument;
|
|||
@JsonIgnoreProperties({"stats", "listeners", "iterationCount", "rng", "lastExportedRDDId", "lastRDDExportPath",
|
||||
"trainingMasterUID"})
|
||||
@EqualsAndHashCode(exclude = {"stats", "listeners", "iterationCount", "rng", "lastExportedRDDId", "lastRDDExportPath",
|
||||
"trainingMasterUID"})
|
||||
"trainingMasterUID"}, callSuper = false)
|
||||
@Slf4j
|
||||
|
||||
public class ParameterAveragingTrainingMaster
|
||||
extends BaseTrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker>
|
||||
implements TrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker> {
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.deeplearning4j.spark.parameterserver.training;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
|
@ -100,7 +101,7 @@ import java.util.concurrent.atomic.AtomicInteger;
|
|||
*/
|
||||
@Slf4j
|
||||
@Data
|
||||
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class SharedTrainingMaster extends BaseTrainingMaster<SharedTrainingResult, SharedTrainingWorker>
|
||||
implements TrainingMaster<SharedTrainingResult, SharedTrainingWorker> {
|
||||
//Static counter/id fields used to determine which training master last set up the singleton param servers, etc
|
||||
|
|
|
@ -53,16 +53,16 @@ public class ParameterServerClient implements NDArrayCallback {
|
|||
//port to listen on for the subscriber
|
||||
private int subscriberPort;
|
||||
//the stream to listen on for the subscriber
|
||||
private int subscriberStream = 11;
|
||||
@Builder.Default private int subscriberStream = 11;
|
||||
//the "current" ndarray
|
||||
private AtomicReference<INDArray> arr;
|
||||
private INDArray none = Nd4j.scalar(1.0);
|
||||
@Builder.Default private INDArray none = Nd4j.scalar(1.0);
|
||||
private AtomicBoolean running;
|
||||
private String masterStatusHost;
|
||||
private int masterStatusPort;
|
||||
private ObjectMapper objectMapper = new ObjectMapper();
|
||||
@Builder.Default private ObjectMapper objectMapper = new ObjectMapper();
|
||||
private Aeron aeron;
|
||||
private boolean compressArray = true;
|
||||
@Builder.Default private boolean compressArray = true;
|
||||
|
||||
/**
|
||||
* Tracks number of
|
||||
|
|
|
@ -47,7 +47,7 @@ public class TextGenerationLSTM extends ZooModel {
|
|||
@Builder.Default private long seed = 1234;
|
||||
@Builder.Default private int maxLength = 40;
|
||||
@Builder.Default private int totalUniqueCharacters = 47;
|
||||
private int[] inputShape = new int[] {maxLength, totalUniqueCharacters};
|
||||
@Builder.Default private int[] inputShape = new int[] {maxLength, totalUniqueCharacters};
|
||||
@Builder.Default private IUpdater updater = new RmsProp(0.01);
|
||||
@Builder.Default private CacheMode cacheMode = CacheMode.NONE;
|
||||
@Builder.Default private WorkspaceMode workspaceMode = WorkspaceMode.ENABLED;
|
||||
|
|
Loading…
Reference in New Issue