parent
acdd9c0a8a
commit
b8a21bc991
|
@ -21,6 +21,7 @@
|
||||||
package org.datavec.api.transform.ndarray;
|
package org.datavec.api.transform.ndarray;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.datavec.api.transform.ColumnType;
|
import org.datavec.api.transform.ColumnType;
|
||||||
import org.datavec.api.transform.MathOp;
|
import org.datavec.api.transform.MathOp;
|
||||||
import org.datavec.api.transform.metadata.ColumnMetaData;
|
import org.datavec.api.transform.metadata.ColumnMetaData;
|
||||||
|
@ -36,6 +37,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class NDArrayColumnsMathOpTransform extends BaseColumnsMathOpTransform {
|
public class NDArrayColumnsMathOpTransform extends BaseColumnsMathOpTransform {
|
||||||
|
|
||||||
public NDArrayColumnsMathOpTransform(@JsonProperty("newColumnName") String newColumnName,
|
public NDArrayColumnsMathOpTransform(@JsonProperty("newColumnName") String newColumnName,
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.datavec.api.transform.ndarray;
|
package org.datavec.api.transform.ndarray;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.datavec.api.transform.MathFunction;
|
import org.datavec.api.transform.MathFunction;
|
||||||
import org.datavec.api.transform.metadata.ColumnMetaData;
|
import org.datavec.api.transform.metadata.ColumnMetaData;
|
||||||
import org.datavec.api.transform.transform.BaseColumnTransform;
|
import org.datavec.api.transform.transform.BaseColumnTransform;
|
||||||
|
@ -32,6 +33,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class NDArrayMathFunctionTransform extends BaseColumnTransform {
|
public class NDArrayMathFunctionTransform extends BaseColumnTransform {
|
||||||
|
|
||||||
//Can't guarantee that the writable won't be re-used, for example in different Spark ops on the same RDD
|
//Can't guarantee that the writable won't be re-used, for example in different Spark ops on the same RDD
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.datavec.api.transform.ndarray;
|
package org.datavec.api.transform.ndarray;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.datavec.api.transform.MathOp;
|
import org.datavec.api.transform.MathOp;
|
||||||
import org.datavec.api.transform.metadata.ColumnMetaData;
|
import org.datavec.api.transform.metadata.ColumnMetaData;
|
||||||
import org.datavec.api.transform.metadata.NDArrayMetaData;
|
import org.datavec.api.transform.metadata.NDArrayMetaData;
|
||||||
|
@ -33,6 +34,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class NDArrayScalarOpTransform extends BaseColumnTransform {
|
public class NDArrayScalarOpTransform extends BaseColumnTransform {
|
||||||
|
|
||||||
private final MathOp mathOp;
|
private final MathOp mathOp;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.datavec.api.transform.transform.string;
|
package org.datavec.api.transform.transform.string;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
@ -31,6 +32,7 @@ import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class StringListToIndicesNDArrayTransform extends StringListToCountsNDArrayTransform {
|
public class StringListToIndicesNDArrayTransform extends StringListToCountsNDArrayTransform {
|
||||||
/**
|
/**
|
||||||
* @param columnName The name of the column to convert
|
* @param columnName The name of the column to convert
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.datavec.image.transform;
|
package org.datavec.image.transform;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||||
import org.datavec.image.data.ImageWritable;
|
import org.datavec.image.data.ImageWritable;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -32,6 +33,7 @@ import org.bytedeco.opencv.opencv_core.*;
|
||||||
import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class LargestBlobCropTransform extends BaseImageTransform<Mat> {
|
public class LargestBlobCropTransform extends BaseImageTransform<Mat> {
|
||||||
|
|
||||||
protected org.nd4j.linalg.api.rng.Random rng;
|
protected org.nd4j.linalg.api.rng.Random rng;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.datavec.image.transform;
|
package org.datavec.image.transform;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
|
|
||||||
import org.datavec.image.data.ImageWritable;
|
import org.datavec.image.data.ImageWritable;
|
||||||
|
@ -32,6 +33,7 @@ import java.util.*;
|
||||||
import org.bytedeco.opencv.opencv_core.*;
|
import org.bytedeco.opencv.opencv_core.*;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class PipelineImageTransform extends BaseImageTransform<Mat> {
|
public class PipelineImageTransform extends BaseImageTransform<Mat> {
|
||||||
|
|
||||||
protected List<Pair<ImageTransform, Double>> imageTransforms;
|
protected List<Pair<ImageTransform, Double>> imageTransforms;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.datavec.image.transform;
|
package org.datavec.image.transform;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||||
import org.datavec.image.data.ImageWritable;
|
import org.datavec.image.data.ImageWritable;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -35,6 +36,7 @@ import org.bytedeco.opencv.opencv_core.*;
|
||||||
@JsonIgnoreProperties({"rng", "converter"})
|
@JsonIgnoreProperties({"rng", "converter"})
|
||||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class RandomCropTransform extends BaseImageTransform<Mat> {
|
public class RandomCropTransform extends BaseImageTransform<Mat> {
|
||||||
|
|
||||||
protected int outputHeight;
|
protected int outputHeight;
|
||||||
|
|
|
@ -20,11 +20,8 @@
|
||||||
|
|
||||||
package org.nd4j.autodiff.functions;
|
package org.nd4j.autodiff.functions;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.*;
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
|
|
@ -46,7 +46,9 @@ public class SDVariable implements Serializable {
|
||||||
protected SameDiff sameDiff;
|
protected SameDiff sameDiff;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
|
@Setter
|
||||||
protected String varName;
|
protected String varName;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
@Setter
|
@Setter
|
||||||
protected VariableType variableType;
|
protected VariableType variableType;
|
||||||
|
@ -83,18 +85,6 @@ public class SDVariable implements Serializable {
|
||||||
return varName;
|
return varName;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setVarName(String varName) {
|
|
||||||
this.varName = varName;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @deprecated Use {@link #name()}
|
|
||||||
*/
|
|
||||||
@Deprecated
|
|
||||||
public String getVarName(){
|
|
||||||
return name();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns true if this variable is a place holder
|
* Returns true if this variable is a place holder
|
||||||
* @return
|
* @return
|
||||||
|
|
|
@ -39,5 +39,6 @@ public class Variable {
|
||||||
protected String outputOfOp; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of
|
protected String outputOfOp; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of
|
||||||
protected List<String> controlDeps; //Control dependencies: name of ops that must be available before this variable is considered available for execution
|
protected List<String> controlDeps; //Control dependencies: name of ops that must be available before this variable is considered available for execution
|
||||||
protected SDVariable gradient; //Variable corresponding to the gradient of this variable
|
protected SDVariable gradient; //Variable corresponding to the gradient of this variable
|
||||||
|
@Builder.Default
|
||||||
protected int variableIndex = -1;
|
protected int variableIndex = -1;
|
||||||
}
|
}
|
||||||
|
|
|
@ -76,7 +76,7 @@ public class SameDiffUtils {
|
||||||
|
|
||||||
public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, Map<String, INDArray> externalGradients, SDVariable... inputs) {
|
public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, Map<String, INDArray> externalGradients, SDVariable... inputs) {
|
||||||
Preconditions.checkArgument(inputs != null && inputs.length > 0, "Require at least one SDVariable to" +
|
Preconditions.checkArgument(inputs != null && inputs.length > 0, "Require at least one SDVariable to" +
|
||||||
" be specified when using external errors: got %s", inputs);
|
" be specified when using external errors: got %s", (Object) inputs);
|
||||||
ExternalErrorsFunction fn = new ExternalErrorsFunction(sameDiff, Arrays.asList(inputs), externalGradients);
|
ExternalErrorsFunction fn = new ExternalErrorsFunction(sameDiff, Arrays.asList(inputs), externalGradients);
|
||||||
fn.outputVariable();
|
fn.outputVariable();
|
||||||
return fn;
|
return fn;
|
||||||
|
|
|
@ -49,7 +49,7 @@ import java.io.Serializable;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration> {
|
public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration> {
|
||||||
|
|
||||||
public static final int DEFAULT_RELIABILITY_DIAG_NUM_BINS = 10;
|
public static final int DEFAULT_RELIABILITY_DIAG_NUM_BINS = 10;
|
||||||
|
|
|
@ -22,8 +22,10 @@ package org.nd4j.evaluation.curves;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class Histogram extends BaseHistogram {
|
public class Histogram extends BaseHistogram {
|
||||||
private final String title;
|
private final String title;
|
||||||
private final double lower;
|
private final double lower;
|
||||||
|
|
|
@ -21,6 +21,8 @@
|
||||||
package org.nd4j.linalg.api.memory.deallocation;
|
package org.nd4j.linalg.api.memory.deallocation;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.linalg.api.memory.Deallocatable;
|
import org.nd4j.linalg.api.memory.Deallocatable;
|
||||||
import org.nd4j.linalg.api.memory.Deallocator;
|
import org.nd4j.linalg.api.memory.Deallocator;
|
||||||
|
|
||||||
|
@ -28,6 +30,7 @@ import java.lang.ref.ReferenceQueue;
|
||||||
import java.lang.ref.WeakReference;
|
import java.lang.ref.WeakReference;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class DeallocatableReference extends WeakReference<Deallocatable> {
|
public class DeallocatableReference extends WeakReference<Deallocatable> {
|
||||||
private String id;
|
private String id;
|
||||||
private Deallocator deallocator;
|
private Deallocator deallocator;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.api.ops;
|
package org.nd4j.linalg.api.ops;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -36,6 +37,7 @@ import java.util.List;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccumulation {
|
public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccumulation {
|
||||||
protected boolean keepDims = false;
|
protected boolean keepDims = false;
|
||||||
|
|
||||||
|
|
|
@ -44,6 +44,9 @@ import java.lang.reflect.Array;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
@Builder
|
||||||
|
@AllArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = true)
|
||||||
public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
||||||
|
|
||||||
private String opName;
|
private String opName;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.api.ops.custom;
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
@ -35,6 +36,7 @@ import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class Flatten extends DynamicCustomOp {
|
public class Flatten extends DynamicCustomOp {
|
||||||
private int order;
|
private int order;
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,8 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
|
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
@ -39,6 +41,7 @@ import org.tensorflow.framework.NodeDef;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public abstract class BaseCompatOp extends DynamicCustomOp {
|
public abstract class BaseCompatOp extends DynamicCustomOp {
|
||||||
protected String frameName;
|
protected String frameName;
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
|
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -37,6 +38,7 @@ import java.util.Map;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class Enter extends BaseCompatOp {
|
public class Enter extends BaseCompatOp {
|
||||||
|
|
||||||
protected boolean isConstant;
|
protected boolean isConstant;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
|
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -36,6 +37,7 @@ import java.util.Map;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class While extends BaseCompatOp {
|
public class While extends BaseCompatOp {
|
||||||
|
|
||||||
protected boolean isConstant;
|
protected boolean isConstant;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
@ -36,6 +37,7 @@ import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class FirstIndex extends BaseIndexAccumulation {
|
public class FirstIndex extends BaseIndexAccumulation {
|
||||||
protected Condition condition;
|
protected Condition condition;
|
||||||
protected double compare;
|
protected double compare;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
@ -38,6 +39,7 @@ import java.util.Map;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class LastIndex extends BaseIndexAccumulation {
|
public class LastIndex extends BaseIndexAccumulation {
|
||||||
protected Condition condition;
|
protected Condition condition;
|
||||||
protected double compare;
|
protected double compare;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
@ -38,6 +39,7 @@ import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class ArgAmax extends DynamicCustomOp {
|
public class ArgAmax extends DynamicCustomOp {
|
||||||
protected boolean keepDims = false;
|
protected boolean keepDims = false;
|
||||||
private int[] dimensions;
|
private int[] dimensions;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
@ -38,6 +39,7 @@ import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class ArgAmin extends DynamicCustomOp {
|
public class ArgAmin extends DynamicCustomOp {
|
||||||
protected boolean keepDims = false;
|
protected boolean keepDims = false;
|
||||||
private int[] dimensions;
|
private int[] dimensions;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
@ -37,6 +38,7 @@ import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class ArgMax extends DynamicCustomOp {
|
public class ArgMax extends DynamicCustomOp {
|
||||||
protected boolean keepDims = false;
|
protected boolean keepDims = false;
|
||||||
private int[] dimensions;
|
private int[] dimensions;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
@ -37,6 +38,7 @@ import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class ArgMin extends DynamicCustomOp {
|
public class ArgMin extends DynamicCustomOp {
|
||||||
protected boolean keepDims = false;
|
protected boolean keepDims = false;
|
||||||
private int[] dimensions;
|
private int[] dimensions;
|
||||||
|
|
|
@ -22,16 +22,15 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
||||||
|
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
import lombok.*;
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import lombok.NonNull;
|
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.linalg.util.ConvConfigUtil;
|
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class Conv1DConfig extends BaseConvolutionConfig {
|
public class Conv1DConfig extends BaseConvolutionConfig {
|
||||||
public static final String NCW = "NCW";
|
public static final String NCW = "NCW";
|
||||||
public static final String NWC = "NWC";
|
public static final String NWC = "NWC";
|
||||||
|
|
|
@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.enums.WeightsFormat;
|
import org.nd4j.enums.WeightsFormat;
|
||||||
|
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class Conv2DConfig extends BaseConvolutionConfig {
|
public class Conv2DConfig extends BaseConvolutionConfig {
|
||||||
public static final String NCHW = "NCHW";
|
public static final String NCHW = "NCHW";
|
||||||
public static final String NHWC = "NHWC";
|
public static final String NHWC = "NHWC";
|
||||||
|
|
|
@ -25,6 +25,7 @@ import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.linalg.util.ConvConfigUtil;
|
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
|
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class Conv3DConfig extends BaseConvolutionConfig {
|
public class Conv3DConfig extends BaseConvolutionConfig {
|
||||||
public static final String NDHWC = "NDHWC";
|
public static final String NDHWC = "NDHWC";
|
||||||
public static final String NCDHW = "NCDHW";
|
public static final String NCDHW = "NCDHW";
|
||||||
|
|
|
@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.linalg.util.ConvConfigUtil;
|
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
|
@ -31,6 +32,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class DeConv2DConfig extends BaseConvolutionConfig {
|
public class DeConv2DConfig extends BaseConvolutionConfig {
|
||||||
public static final String NCHW = "NCHW";
|
public static final String NCHW = "NCHW";
|
||||||
public static final String NHWC = "NHWC";
|
public static final String NHWC = "NHWC";
|
||||||
|
|
|
@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.linalg.util.ConvConfigUtil;
|
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
|
@ -31,6 +32,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class DeConv3DConfig extends BaseConvolutionConfig {
|
public class DeConv3DConfig extends BaseConvolutionConfig {
|
||||||
public static final String NCDHW = "NCDHW";
|
public static final String NCDHW = "NCDHW";
|
||||||
public static final String NDHWC = "NDHWC";
|
public static final String NDHWC = "NDHWC";
|
||||||
|
|
|
@ -24,12 +24,14 @@ import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.linalg.util.ConvConfigUtil;
|
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class LocalResponseNormalizationConfig extends BaseConvolutionConfig {
|
public class LocalResponseNormalizationConfig extends BaseConvolutionConfig {
|
||||||
|
|
||||||
private double alpha, beta, bias;
|
private double alpha, beta, bias;
|
||||||
|
|
|
@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.Divisor;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.Divisor;
|
||||||
|
@ -33,6 +34,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class Pooling2DConfig extends BaseConvolutionConfig {
|
public class Pooling2DConfig extends BaseConvolutionConfig {
|
||||||
|
|
||||||
@Builder.Default private long kH = -1, kW = -1;
|
@Builder.Default private long kH = -1, kW = -1;
|
||||||
|
|
|
@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D.Pooling3DType;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D.Pooling3DType;
|
||||||
|
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class Pooling3DConfig extends BaseConvolutionConfig {
|
public class Pooling3DConfig extends BaseConvolutionConfig {
|
||||||
@Builder.Default private long kD = -1, kW = -1, kH = -1; // kernel
|
@Builder.Default private long kD = -1, kW = -1, kH = -1; // kernel
|
||||||
@Builder.Default private long sD = 1, sW = 1, sH = 1; // strides
|
@Builder.Default private long sD = 1, sW = 1, sH = 1; // strides
|
||||||
|
|
|
@ -39,7 +39,7 @@ import org.tensorflow.framework.NodeDef;
|
||||||
import java.lang.reflect.Field;
|
import java.lang.reflect.Field;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class Mmul extends DynamicCustomOp {
|
public class Mmul extends DynamicCustomOp {
|
||||||
|
|
||||||
protected MMulTranspose mt;
|
protected MMulTranspose mt;
|
||||||
|
|
|
@ -32,7 +32,7 @@ import org.nd4j.common.util.ArrayUtil;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class MmulBp extends DynamicCustomOp {
|
public class MmulBp extends DynamicCustomOp {
|
||||||
|
|
||||||
protected MMulTranspose mt;
|
protected MMulTranspose mt;
|
||||||
|
|
|
@ -32,7 +32,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class BatchMmul extends DynamicCustomOp {
|
public class BatchMmul extends DynamicCustomOp {
|
||||||
|
|
||||||
protected int transposeA;
|
protected int transposeA;
|
||||||
|
|
|
@ -39,10 +39,15 @@ import java.util.Map;
|
||||||
public class BalanceMinibatches {
|
public class BalanceMinibatches {
|
||||||
private DataSetIterator dataSetIterator;
|
private DataSetIterator dataSetIterator;
|
||||||
private int numLabels;
|
private int numLabels;
|
||||||
|
@Builder.Default
|
||||||
private Map<Integer, List<File>> paths = Maps.newHashMap();
|
private Map<Integer, List<File>> paths = Maps.newHashMap();
|
||||||
|
@Builder.Default
|
||||||
private int miniBatchSize = -1;
|
private int miniBatchSize = -1;
|
||||||
|
@Builder.Default
|
||||||
private File rootDir = new File("minibatches");
|
private File rootDir = new File("minibatches");
|
||||||
|
@Builder.Default
|
||||||
private File rootSaveDir = new File("minibatchessave");
|
private File rootSaveDir = new File("minibatchessave");
|
||||||
|
@Builder.Default
|
||||||
private List<File> labelRootDirs = new ArrayList<>();
|
private List<File> labelRootDirs = new ArrayList<>();
|
||||||
private DataNormalization dataNormalization;
|
private DataNormalization dataNormalization;
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,8 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.learning.config;
|
package org.nd4j.linalg.learning.config;
|
||||||
|
|
||||||
import lombok.*;
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.learning.AdaMaxUpdater;
|
import org.nd4j.linalg.learning.AdaMaxUpdater;
|
||||||
import org.nd4j.linalg.learning.GradientUpdater;
|
import org.nd4j.linalg.learning.GradientUpdater;
|
||||||
|
@ -44,7 +45,8 @@ public class AdaMax implements IUpdater {
|
||||||
public static final double DEFAULT_ADAMAX_BETA1_MEAN_DECAY = 0.9;
|
public static final double DEFAULT_ADAMAX_BETA1_MEAN_DECAY = 0.9;
|
||||||
public static final double DEFAULT_ADAMAX_BETA2_VAR_DECAY = 0.999;
|
public static final double DEFAULT_ADAMAX_BETA2_VAR_DECAY = 0.999;
|
||||||
|
|
||||||
@lombok.Builder.Default private double learningRate = DEFAULT_ADAMAX_LEARNING_RATE; // learning rate
|
@lombok.Builder.Default
|
||||||
|
private double learningRate = DEFAULT_ADAMAX_LEARNING_RATE; // learning rate
|
||||||
private ISchedule learningRateSchedule;
|
private ISchedule learningRateSchedule;
|
||||||
@lombok.Builder.Default private double beta1 = DEFAULT_ADAMAX_BETA1_MEAN_DECAY; // gradient moving avg decay rate
|
@lombok.Builder.Default private double beta1 = DEFAULT_ADAMAX_BETA1_MEAN_DECAY; // gradient moving avg decay rate
|
||||||
@lombok.Builder.Default private double beta2 = DEFAULT_ADAMAX_BETA2_VAR_DECAY; // gradient sqrd decay rate
|
@lombok.Builder.Default private double beta2 = DEFAULT_ADAMAX_BETA2_VAR_DECAY; // gradient sqrd decay rate
|
||||||
|
|
|
@ -335,20 +335,6 @@ public class OpProfiler {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Dev-time method.
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
protected StackAggregator getMixedOrderAggregator() {
|
|
||||||
// FIXME: remove this method, or make it protected
|
|
||||||
return mixedOrderAggregator;
|
|
||||||
}
|
|
||||||
|
|
||||||
public StackAggregator getScalarAggregator() {
|
|
||||||
return scalarAggregator;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void updatePairs(String opName, String opClass) {
|
protected void updatePairs(String opName, String opClass) {
|
||||||
// now we save pairs of ops/classes
|
// now we save pairs of ops/classes
|
||||||
String cOpNameKey = prevOpName + " -> " + opName;
|
String cOpNameKey = prevOpName + " -> " + opName;
|
||||||
|
|
|
@ -25,4 +25,5 @@ dependencies {
|
||||||
|
|
||||||
implementation "org.slf4j:slf4j-api"
|
implementation "org.slf4j:slf4j-api"
|
||||||
implementation "org.apache.commons:commons-lang3"
|
implementation "org.apache.commons:commons-lang3"
|
||||||
|
implementation "com.fasterxml.jackson.core:jackson-annotations"
|
||||||
}
|
}
|
|
@ -21,6 +21,7 @@
|
||||||
package org.deeplearning4j.nn.modelimport.keras.layers;
|
package org.deeplearning4j.nn.modelimport.keras.layers;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.ArrayUtils;
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
|
@ -46,6 +47,7 @@ import java.util.List;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {
|
public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.deeplearning4j.nn.modelimport.keras.layers.core;
|
package org.deeplearning4j.nn.modelimport.keras.layers.core;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
|
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
|
||||||
|
@ -39,6 +40,7 @@ import java.util.Map;
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class KerasMasking extends KerasLayer {
|
public class KerasMasking extends KerasLayer {
|
||||||
|
|
||||||
private double maskingValue;
|
private double maskingValue;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.deeplearning4j.nn.modelimport.keras.layers.core;
|
package org.deeplearning4j.nn.modelimport.keras.layers.core;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
|
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
|
||||||
import org.deeplearning4j.nn.conf.graph.MergeVertex;
|
import org.deeplearning4j.nn.conf.graph.MergeVertex;
|
||||||
|
@ -35,6 +36,7 @@ import java.util.Map;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class KerasMerge extends KerasLayer {
|
public class KerasMerge extends KerasLayer {
|
||||||
|
|
||||||
private final String LAYER_FIELD_MODE = "mode";
|
private final String LAYER_FIELD_MODE = "mode";
|
||||||
|
|
|
@ -55,11 +55,13 @@ public class VPTree implements Serializable {
|
||||||
private Node root;
|
private Node root;
|
||||||
private String similarityFunction;
|
private String similarityFunction;
|
||||||
@Getter
|
@Getter
|
||||||
|
@Builder.Default
|
||||||
private boolean invert = false;
|
private boolean invert = false;
|
||||||
private transient ExecutorService executorService;
|
private transient ExecutorService executorService;
|
||||||
@Getter
|
@Getter
|
||||||
|
@Builder.Default
|
||||||
private int workers = 1;
|
private int workers = 1;
|
||||||
private AtomicInteger size = new AtomicInteger(0);
|
@Builder.Default private AtomicInteger size = new AtomicInteger(0);
|
||||||
|
|
||||||
private transient ThreadLocal<INDArray> scalars = new ThreadLocal<>();
|
private transient ThreadLocal<INDArray> scalars = new ThreadLocal<>();
|
||||||
|
|
||||||
|
|
|
@ -21,11 +21,13 @@
|
||||||
package org.deeplearning4j.eval.curves;
|
package org.deeplearning4j.eval.curves;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.nd4j.evaluation.curves.BaseHistogram;
|
import org.nd4j.evaluation.curves.BaseHistogram;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class Histogram extends org.nd4j.evaluation.curves.Histogram {
|
public class Histogram extends org.nd4j.evaluation.curves.Histogram {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -76,8 +76,9 @@ import static com.google.common.base.Preconditions.checkArgument;
|
||||||
@JsonIgnoreProperties({"stats", "listeners", "iterationCount", "rng", "lastExportedRDDId", "lastRDDExportPath",
|
@JsonIgnoreProperties({"stats", "listeners", "iterationCount", "rng", "lastExportedRDDId", "lastRDDExportPath",
|
||||||
"trainingMasterUID"})
|
"trainingMasterUID"})
|
||||||
@EqualsAndHashCode(exclude = {"stats", "listeners", "iterationCount", "rng", "lastExportedRDDId", "lastRDDExportPath",
|
@EqualsAndHashCode(exclude = {"stats", "listeners", "iterationCount", "rng", "lastExportedRDDId", "lastRDDExportPath",
|
||||||
"trainingMasterUID"})
|
"trainingMasterUID"}, callSuper = false)
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
|
||||||
public class ParameterAveragingTrainingMaster
|
public class ParameterAveragingTrainingMaster
|
||||||
extends BaseTrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker>
|
extends BaseTrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker>
|
||||||
implements TrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker> {
|
implements TrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker> {
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.deeplearning4j.spark.parameterserver.training;
|
package org.deeplearning4j.spark.parameterserver.training;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
@ -100,7 +101,7 @@ import java.util.concurrent.atomic.AtomicInteger;
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class SharedTrainingMaster extends BaseTrainingMaster<SharedTrainingResult, SharedTrainingWorker>
|
public class SharedTrainingMaster extends BaseTrainingMaster<SharedTrainingResult, SharedTrainingWorker>
|
||||||
implements TrainingMaster<SharedTrainingResult, SharedTrainingWorker> {
|
implements TrainingMaster<SharedTrainingResult, SharedTrainingWorker> {
|
||||||
//Static counter/id fields used to determine which training master last set up the singleton param servers, etc
|
//Static counter/id fields used to determine which training master last set up the singleton param servers, etc
|
||||||
|
|
|
@ -53,16 +53,16 @@ public class ParameterServerClient implements NDArrayCallback {
|
||||||
//port to listen on for the subscriber
|
//port to listen on for the subscriber
|
||||||
private int subscriberPort;
|
private int subscriberPort;
|
||||||
//the stream to listen on for the subscriber
|
//the stream to listen on for the subscriber
|
||||||
private int subscriberStream = 11;
|
@Builder.Default private int subscriberStream = 11;
|
||||||
//the "current" ndarray
|
//the "current" ndarray
|
||||||
private AtomicReference<INDArray> arr;
|
private AtomicReference<INDArray> arr;
|
||||||
private INDArray none = Nd4j.scalar(1.0);
|
@Builder.Default private INDArray none = Nd4j.scalar(1.0);
|
||||||
private AtomicBoolean running;
|
private AtomicBoolean running;
|
||||||
private String masterStatusHost;
|
private String masterStatusHost;
|
||||||
private int masterStatusPort;
|
private int masterStatusPort;
|
||||||
private ObjectMapper objectMapper = new ObjectMapper();
|
@Builder.Default private ObjectMapper objectMapper = new ObjectMapper();
|
||||||
private Aeron aeron;
|
private Aeron aeron;
|
||||||
private boolean compressArray = true;
|
@Builder.Default private boolean compressArray = true;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tracks number of
|
* Tracks number of
|
||||||
|
|
|
@ -47,7 +47,7 @@ public class TextGenerationLSTM extends ZooModel {
|
||||||
@Builder.Default private long seed = 1234;
|
@Builder.Default private long seed = 1234;
|
||||||
@Builder.Default private int maxLength = 40;
|
@Builder.Default private int maxLength = 40;
|
||||||
@Builder.Default private int totalUniqueCharacters = 47;
|
@Builder.Default private int totalUniqueCharacters = 47;
|
||||||
private int[] inputShape = new int[] {maxLength, totalUniqueCharacters};
|
@Builder.Default private int[] inputShape = new int[] {maxLength, totalUniqueCharacters};
|
||||||
@Builder.Default private IUpdater updater = new RmsProp(0.01);
|
@Builder.Default private IUpdater updater = new RmsProp(0.01);
|
||||||
@Builder.Default private CacheMode cacheMode = CacheMode.NONE;
|
@Builder.Default private CacheMode cacheMode = CacheMode.NONE;
|
||||||
@Builder.Default private WorkspaceMode workspaceMode = WorkspaceMode.ENABLED;
|
@Builder.Default private WorkspaceMode workspaceMode = WorkspaceMode.ENABLED;
|
||||||
|
|
Loading…
Reference in New Issue