diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/AlphaDropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/AlphaDropout.java index c48048fa7..6382bb808 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/AlphaDropout.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/AlphaDropout.java @@ -24,11 +24,10 @@ import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp; import org.nd4j.linalg.api.ops.random.impl.AlphaDropOut; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.schedule.ISchedule; import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties; import org.nd4j.shade.jackson.annotation.JsonProperty; @@ -139,7 +138,7 @@ public class AlphaDropout implements IDropout { //a * (x * d + alphaPrime * (1-d)) + b INDArray inverseMask = mask.rsub(1.0); INDArray aPOneMinusD = inverseMask.muli(alphaPrime); - Nd4j.getExecutioner().exec(new OldMulOp(inputActivations, mask, output)); //out = x * d + Nd4j.getExecutioner().exec(new MulOp(inputActivations, mask, output)); //out = x * d output.addi(aPOneMinusD).muli(a).addi(b); //Nd4j.getExecutioner().exec(new AlphaDropOut(inputActivations, output, p, a, alphaPrime, b)); @@ -152,7 +151,7 @@ public class AlphaDropout implements IDropout { //dL/dIn = dL/dOut * dOut/dIn // dOut/dIn = 0 if dropped (d=0), or a otherwise (d=1) mask.muli(a); - Nd4j.getExecutioner().exec(new OldMulOp(gradAtOutput, mask, gradAtInput)); + Nd4j.getExecutioner().exec(new MulOp(gradAtOutput, mask, gradAtInput)); mask = null; return gradAtInput; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java index 899b4382a..f9af153ad 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java @@ -24,7 +24,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp; import org.nd4j.linalg.api.ops.random.impl.DropOutInverted; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.schedule.ISchedule; @@ -153,7 +153,7 @@ public class Dropout implements IDropout { mask = workspaceMgr.createUninitialized(ArrayType.INPUT, output.dataType(), output.shape(), output.ordering()).assign(1.0); Nd4j.getExecutioner().exec(new DropOutInverted(mask, mask, currP)); - Nd4j.getExecutioner().exec(new OldMulOp(inputCast, mask, output)); + Nd4j.getExecutioner().exec(new MulOp(inputCast, mask, output)); return output; } @@ -171,7 +171,7 @@ public class Dropout implements IDropout { if(m.dataType() != gradAtInput.dataType()){ m = m.castTo(gradAtInput.dataType()); } - 
Nd4j.getExecutioner().exec(new OldMulOp(gradAtOutput, m, gradAtInput)); + Nd4j.getExecutioner().exec(new MulOp(gradAtOutput, m, gradAtInput)); mask = null; return gradAtInput; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianDropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianDropout.java index 4f364ee36..d42b79d29 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianDropout.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianDropout.java @@ -22,7 +22,7 @@ import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp; import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.schedule.ISchedule; @@ -88,7 +88,7 @@ public class GaussianDropout implements IDropout { noise = workspaceMgr.createUninitialized(ArrayType.INPUT, output.dataType(), inputActivations.shape(), inputActivations.ordering()); Nd4j.getExecutioner().exec(new GaussianDistribution(noise, 1.0, stdev)); - return Nd4j.getExecutioner().exec(new OldMulOp(inputActivations, noise, output)); + return Nd4j.getExecutioner().exec(new MulOp(inputActivations, noise, output))[0]; } @Override @@ -96,7 +96,7 @@ public class GaussianDropout implements IDropout { Preconditions.checkState(noise != null, "Cannot perform backprop: GaussianDropout noise array is absent (already cleared?)"); //out = in*y, where y ~ N(1, stdev) //dL/dIn = dL/dOut * dOut/dIn = y * dL/dOut - Nd4j.getExecutioner().exec(new OldMulOp(gradAtOutput, noise, gradAtInput)); + Nd4j.getExecutioner().exec(new MulOp(gradAtOutput, noise, gradAtInput)); noise = null; return gradAtInput; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianNoise.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianNoise.java index 89b5edf6c..d165614ab 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianNoise.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianNoise.java @@ -19,7 +19,7 @@ package org.deeplearning4j.nn.conf.dropout; import lombok.Data; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAddOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.schedule.ISchedule; @@ -69,7 +69,7 @@ public class GaussianNoise implements IDropout { INDArray noise = Nd4j.createUninitialized(output.dataType(), inputActivations.shape(), inputActivations.ordering()); Nd4j.getExecutioner().exec(new GaussianDistribution(noise, 0, currS)); - Nd4j.getExecutioner().exec(new OldAddOp(inputActivations, noise, output)); + Nd4j.getExecutioner().exec(new AddOp(inputActivations, noise, output)); return output; } diff --git 
a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/BernoulliReconstructionDistribution.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/BernoulliReconstructionDistribution.java index 509c3bfb1..8671bef66 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/BernoulliReconstructionDistribution.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/BernoulliReconstructionDistribution.java @@ -24,7 +24,7 @@ import org.nd4j.linalg.activations.impl.ActivationHardSigmoid; import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldLessThan; +import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.conditions.Conditions; @@ -144,7 +144,7 @@ public class BernoulliReconstructionDistribution implements ReconstructionDistri INDArray out = Nd4j.createUninitialized(DataType.BOOL, p.shape()); - Nd4j.getExecutioner().execAndReturn(new OldLessThan(rand, p, out)); + Nd4j.getExecutioner().execAndReturn(new LessThan(rand, p, out)); return out.castTo(DataType.FLOAT); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/WeightNoise.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/WeightNoise.java index fea446247..20beffafd 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/WeightNoise.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/WeightNoise.java @@ -22,8 +22,8 @@ import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.distribution.Distributions; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAddOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp; import org.nd4j.linalg.factory.Nd4j; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.ArrayType; @@ -86,9 +86,9 @@ public class WeightNoise implements IWeightNoise { INDArray out = workspaceMgr.createUninitialized(ArrayType.INPUT, param.dataType(), param.shape(), param.ordering()); if (additive) { - Nd4j.getExecutioner().exec(new OldAddOp(param, noise,out)); + Nd4j.getExecutioner().exec(new AddOp(param, noise,out)); } else { - Nd4j.getExecutioner().exec(new OldMulOp(param, noise, out)); + Nd4j.getExecutioner().exec(new MulOp(param, noise, out)); } return out; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java index 5f78001d0..3fb2046c3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java +++ 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java @@ -34,8 +34,8 @@ import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldDivOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldSubOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.exception.ND4JOpProfilerException; import org.nd4j.linalg.factory.Nd4j; @@ -205,7 +205,7 @@ public class BatchNormalization extends BaseLayer(retGradient, nextEpsilon); } @@ -257,7 +256,7 @@ public class LocalResponseNormalization unitScale = sumPart.mul(alpha).addi(k); // y = x * unitScale**-beta scale = Transforms.pow(unitScale, -beta, true); - Nd4j.getExecutioner().exec(new OldMulOp(input, scale, activations)); + Nd4j.getExecutioner().exec(new MulOp(input, scale, activations)); } else { // unitScale = (k + alpha * sum_{j=max(0, i - n/2)}^{max(N-1, i + n/2)} (a^j_{x,y})^2 ) sumPart.muli(alpha, activations).addi(k); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTMHelpers.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTMHelpers.java index 1682ab177..10cf9cde5 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTMHelpers.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTMHelpers.java @@ -35,7 +35,7 @@ import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp; import org.nd4j.linalg.api.ops.impl.transforms.same.TimesOneMinus; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; @@ -554,7 +554,7 @@ public class LSTMHelpers { //Normally would use zo.dup() in above line, but won't be using zo again (for this time step). 
Ditto for zf, zg, zi INDArray deltao = deltaoNext; - Nd4j.getExecutioner().exec(new OldMulOp(nablaOut, sigmahOfS, deltao)); + Nd4j.getExecutioner().exec(new MulOp(nablaOut, sigmahOfS, deltao)); if (sigmoidGates) { INDArray sigmaoPrimeOfZo = Nd4j.getExecutioner().exec(new TimesOneMinus(ao.dup('f'))); //Equivalent to sigmoid deriv on zo deltao.muli(sigmaoPrimeOfZo); @@ -607,7 +607,7 @@ public class LSTMHelpers { deltag.muli(ai); deltag.muli(nablaCellState); } else { - INDArray temp2 = Nd4j.getExecutioner().exec(new OldMulOp(ai, nablaCellState, Nd4j.createUninitialized(inputWeights.dataType(), ai.shape(), 'f'))); + INDArray temp2 = Nd4j.getExecutioner().exec(new MulOp(ai, nablaCellState, Nd4j.createUninitialized(inputWeights.dataType(), ai.shape(), 'f')))[0]; deltag.assign(gateActivationFn.backprop(fwdPass.gz[time], temp2).getFirst()); //TODO activation functions with params; optimize (no assign) } @@ -616,7 +616,7 @@ public class LSTMHelpers { //Network input delta: INDArray zi = fwdPass.iz[time]; INDArray deltai = deltaiNext; - temp = Nd4j.getExecutioner().exec(new OldMulOp(ag, nablaCellState, Nd4j.createUninitialized(inputWeights.dataType(), deltai.shape(), 'f'))); + temp = Nd4j.getExecutioner().exec(new MulOp(ag, nablaCellState, Nd4j.createUninitialized(inputWeights.dataType(), deltai.shape(), 'f')))[0]; deltai.assign(afn.backprop(zi, temp).getFirst()); //TODO activation functions with params; also: optimize this (no assign) //Shape: [m,n^L] diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java index f4d167866..0f29bc837 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java @@ -26,7 +26,6 @@ import onnx.OnnxProto3; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; -import org.nd4j.graph.DataType; import org.nd4j.imports.converters.DifferentialFunctionClassHolder; import org.nd4j.imports.descriptors.properties.AttributeAdapter; import org.nd4j.imports.descriptors.properties.PropertyMapping; @@ -520,7 +519,7 @@ public abstract class DifferentialFunction { * @return the arguments for a given function */ public SDVariable[] args() { - return sameDiff.getInputVariablesForFunction(this); + return sameDiff.getInputVariablesForOp(this); } /** @@ -661,7 +660,7 @@ public abstract class DifferentialFunction { } if(sameDiff != null && !(this instanceof SDVariable)) - sameDiff.putFunctionForId(ownName,this); + sameDiff.putOpForId(ownName,this); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java index 2f0dbd5b5..36334b648 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java @@ -23,6 +23,7 @@ import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.List; 
import java.util.Map; @@ -39,12 +40,12 @@ import org.nd4j.evaluation.IMetric; @Getter public class EvaluationRecord { - private ImmutableMap<String, List<IEvaluation>> evaluations; + private Map<String, List<IEvaluation>> evaluations; private Map<Class<? extends IEvaluation>, IEvaluation> classEvaluations = new HashMap<>(); private boolean isEmpty = true; public EvaluationRecord(Map<String, List<IEvaluation>> evaluations) { - this.evaluations = ImmutableMap.copyOf(evaluations); + this.evaluations = Collections.unmodifiableMap(evaluations); for (List<IEvaluation> le : evaluations.values()) { for (IEvaluation e : le) { @@ -68,7 +69,7 @@ public class EvaluationRecord { /** * Get all evaluations */ - public ImmutableMap<String, List<IEvaluation>> evaluations() { + public Map<String, List<IEvaluation>> evaluations() { return evaluations; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java index f43d41841..f0dcecb49 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java @@ -16,8 +16,8 @@ package org.nd4j.autodiff.listeners.records; -import com.google.common.collect.ImmutableList; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import lombok.Getter; @@ -49,11 +49,11 @@ public class History { public History(List<EvaluationRecord> training, List<EvaluationRecord> validation, LossCurve loss, long trainingTimeMillis, List<Long> validationTimesMillis){ - trainingHistory = ImmutableList.copyOf(training); - validationHistory = ImmutableList.copyOf(validation); + trainingHistory = Collections.unmodifiableList(training); + validationHistory = Collections.unmodifiableList(validation); this.lossCurve = loss; this.trainingTimeMillis = trainingTimeMillis; - this.validationTimesMillis = ImmutableList.copyOf(validationTimesMillis); + this.validationTimesMillis = Collections.unmodifiableList(validationTimesMillis); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/LossCurve.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/LossCurve.java index a65efe180..493950bbf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/LossCurve.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/LossCurve.java @@ -16,8 +16,8 @@ package org.nd4j.autodiff.listeners.records; -import com.google.common.collect.ImmutableList; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import lombok.Getter; import lombok.NonNull; @@ -35,7 +35,7 @@ public class LossCurve { private INDArray lossValues; public LossCurve(List<Loss> losses){ - lossNames = ImmutableList.copyOf(losses.get(0).getLossNames()); + lossNames = Collections.unmodifiableList(losses.get(0).getLossNames()); int numLossValues = losses.get(0).lossValues().length; lossValues = Nd4j.create(DataType.FLOAT, losses.size(), losses.get(0).lossValues().length); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 9cfe87822..a18fd6b06 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -466,8 +466,11 @@ public class SameDiff extends SDBaseOps { } /** - * Set the current {@link Listener} instances. - * Note that + * Set the current SameDiff-wide {@link Listener} instances. + * + * Note that this will overwrite the current listener list. + * If you want to use additional listeners for a single operation, + * use the listener arguments in those methods (e.g. {@link #fit()} and {@link FitConfig#listeners(Listener...)}). * * @param listeners Listeners */ @@ -476,19 +479,37 @@ public class SameDiff extends SDBaseOps { addListeners(listeners); } + /** + * See {@link #setListeners(Listener...)}. + */ public void setListeners(Collection<? extends Listener> listeners) { this.listeners.clear(); addListeners(listeners); } + + /** + * Add SameDiff-wide {@link Listener} instances. + * + * If you want to use additional listeners for a single operation, + * use the listener arguments in those methods (e.g. {@link #fit()} and {@link FitConfig#listeners(Listener...)}). + * + * @param listeners Listeners + */ public void addListeners(Listener... listeners) { addListeners(Arrays.asList(listeners)); } + /** + * See {@link #addListeners(Listener...)}. + */ public void addListeners(Collection<? extends Listener> listeners) { this.listeners.addAll(listeners); } + /** + * Gets the current SameDiff-wide listeners. + */ public List<Listener> getListeners() { return listeners; } @@ -585,6 +606,9 @@ public class SameDiff extends SDBaseOps { } + /** + * Gets all operations in a given name scope. + */ public List<SameDiffOp> getOpsInScope(NameScope scope) { ArrayList<SameDiffOp> ops = new ArrayList<>(); for (SameDiffOp v : this.ops.values()) { @@ -594,6 +618,16 @@ public class SameDiff extends SDBaseOps { return ops; } + /** + * See {@link #getOpsInScope(NameScope)}. + */ + public List<SameDiffOp> getOpsInScope(String scope){ + return getOpsInScope(new NameScope(this, scope)); + } + + /** + * Gets all variables in a given name scope. + */ public List<SDVariable> getVariablesInScope(NameScope scope) { ArrayList<SDVariable> vars = new ArrayList<>(); for (SDVariable v : variables()) { @@ -603,6 +637,13 @@ public class SameDiff extends SDBaseOps { return vars; } + /** + * See {@link #getVariablesInScope(NameScope)}. 
+ */ + public List<SDVariable> getVariablesInScope(String scope){ + return getVariablesInScope(new NameScope(this, scope)); + } + /** * @param sameDiff * @return */ @@ -638,8 +679,8 @@ public class SameDiff extends SDBaseOps { function.getSameDiff()); clone.setSameDiff(sameDiff); clone.setOwnName(function.getOwnName()); - if (sameDiff.functionExists(function.getOwnName())) - sameDiff.putFunctionForId(function.getOwnName(), function); + if (sameDiff.opExists(function.getOwnName())) + sameDiff.putOpForId(function.getOwnName(), function); newFunctions.put(function.getOwnName(), clone); val argsForFunction = function.args(); @@ -672,17 +713,21 @@ * @param id the function id to test for * @return true if the function id exists, false otherwise */ - public boolean functionExists(String id) { + public boolean opExists(String id) { return ops.containsKey(id); } - public DifferentialFunction functionOutputFor(String varName) { - if (variables.get(varName).getOutputOfOp() == null) + /** + * Get the differential function (if any) that this variable is the output for + * + * @param variableName Name of the variable + * @return The differential function that this variable is an output of, or null if it is not the output of a function + */ + public DifferentialFunction getVariableOutputOp(String variableName) { + Preconditions.checkState(variables.containsKey(variableName), "No variable with name \"%s\" found in graph", variableName); + if (variables.get(variableName).getOutputOfOp() == null) return null; - String outName = variables.get(varName).getOutputOfOp(); - if (outName == null) - return null; - return ops.get(outName).getOp(); + return ops.get(variables.get(variableName).getOutputOfOp()).getOp(); } /** @@ -691,7 +736,7 @@ * @param id the id of the function * @return the function for the given id if it exists */ - public DifferentialFunction getFunctionById(@NonNull String id) { + public DifferentialFunction getOpById(@NonNull String id) { if (!ops.containsKey(id)) { throw new ND4JIllegalStateException("No function with id " + id + " found!"); } @@ -705,7 +750,7 @@ * @param id the id of the function * @param function the function */ - public void putFunctionForId(String id, DifferentialFunction function) { + public void putOpForId(String id, DifferentialFunction function) { if (ops.containsKey(id) && ops.get(id).getOp() == null) { throw new ND4JIllegalStateException("Function by id already exists!"); } else if (function instanceof SDVariable) { @@ -726,7 +771,7 @@ * @param function the function to get the inputs for * @return the input ids for a given function */ - public String[] getInputsForFunction(DifferentialFunction function) { + public String[] getInputsForOp(DifferentialFunction function) { if (!ops.containsKey(function.getOwnName())) throw new ND4JIllegalStateException("Illegal function instance id found " + function.getOwnName()); List<String> inputs = ops.get(function.getOwnName()).getInputsToOp(); @@ -739,7 +784,7 @@ * @param function the function to get the outputs for * @return the outputs ids for a given function */ - public String[] getOutputsForFunction(DifferentialFunction function) { + public String[] getOutputsForOp(DifferentialFunction function) { if (!ops.containsKey(function.getOwnName())) throw new ND4JIllegalStateException("Illegal function instance id found " + function.getOwnName()); 
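Most of the SameDiff changes in this file are mechanical renames from function-centric to op-centric accessors: functionExists becomes opExists, getFunctionById becomes getOpById, putFunctionForId becomes putOpForId, and getInputsForFunction/getOutputsForFunction become getInputsForOp/getOutputsForOp. A minimal caller-side sketch of the renamed API (the graph construction and the op name "add" are hypothetical):

import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SameDiff;

public class OpAccessorExample {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        // ... graph construction omitted; "add" is a hypothetical op name
        String opName = "add";
        // Before: sd.functionExists(opName) / sd.getFunctionById(opName)
        if (sd.opExists(opName)) {
            DifferentialFunction df = sd.getOpById(opName);
            String[] inputs = sd.getInputsForOp(df);   // input variable names
            String[] outputs = sd.getOutputsForOp(df); // output variable names
            System.out.println(df.opName() + ": " + inputs.length + " inputs, " + outputs.length + " outputs");
        }
    }
}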
List outputs = ops.get(function.getOwnName()).getOutputsOfOp(); @@ -753,8 +798,8 @@ public class SameDiff extends SDBaseOps { * @param function the function reference to get the output variable(s) for * @return the output variables for the given function */ - public SDVariable[] getOutputVariablesForFunction(DifferentialFunction function) { - val inputs = getOutputsForFunction(function); + public SDVariable[] getOutputVariablesForOp(DifferentialFunction function) { + val inputs = getOutputsForOp(function); if (inputs == null) { throw new ND4JIllegalStateException("No inputs found for function " + function); } @@ -774,8 +819,8 @@ public class SameDiff extends SDBaseOps { * @param function the function reference to get the input variable(s) for * @return the input variables for the given function */ - public SDVariable[] getInputVariablesForFunction(DifferentialFunction function) { - val inputs = getInputsForFunction(function); + public SDVariable[] getInputVariablesForOp(DifferentialFunction function) { + val inputs = getInputsForOp(function); if (inputs == null) { throw new ND4JIllegalStateException("No inputs found for function " + function); } @@ -792,6 +837,10 @@ public class SameDiff extends SDBaseOps { } + /** + * Set the stored {@link INDArray} for a variable. Only works if the variable is of type + * {@link VariableType#CONSTANT}, {@link VariableType#PLACEHOLDER}, or {@link VariableType#VARIABLE}. + */ public void setArrayForVariable(@NonNull String varName, @NonNull INDArray arr) { Preconditions.checkState(variables.containsKey(varName), "No variable with name \"%s\" exists", varName); @@ -830,6 +879,9 @@ public class SameDiff extends SDBaseOps { return variableNameToShape.get(varName); } + /** + * See {@link #getShapeForVarName(String)}, but returns the shape descriptor. + */ public LongShapeDescriptor getShapeDescriptorForVarName(String varName) { if (getVariable(varName).getArr() != null) { return getVariable(varName).getArr().shapeDescriptor(); @@ -861,6 +913,9 @@ public class SameDiff extends SDBaseOps { } + /** + * Sets the shape descriptor for a variable. + */ public void putShapeForVarName(String varName, LongShapeDescriptor shape) { val v = getVariable(varName); putShapeForVarName(varName, shape.getShape()); @@ -1559,19 +1614,6 @@ public class SameDiff extends SDBaseOps { } - /** - * Get the differential function (if any) that this variable is the output for - * - * @param variableName Name of the variable - * @return The differential function that this variable is an output of, or null if it is not the output of a function - */ - public DifferentialFunction getVariableOutputFunction(String variableName) { - Preconditions.checkState(variables.containsKey(variableName), "No variable with name \"%s\" found in graph", variableName); - if (variables.get(variableName).getOutputOfOp() == null) - return null; - return ops.get(variables.get(variableName).getOutputOfOp()).getOp(); - } - /** * Returns true if this function already has defined arguments @@ -1628,7 +1670,7 @@ public class SameDiff extends SDBaseOps { * * @return Array of differential functions */ - public DifferentialFunction[] functions() { + public DifferentialFunction[] ops() { List out = new ArrayList<>(ops.size()); for (SameDiffOp op : ops.values()) { out.add(op.getOp()); @@ -3143,10 +3185,18 @@ public class SameDiff extends SDBaseOps { placeholders, batch, requiredActivations, activeListeners, at); } + /** + * See {@link #one(String, DataType, int...)}. 
+ * Uses the Nd4j default floating point DataType ({@link Nd4j#defaultFloatingPointType()}). + */ public SDVariable one(String name, int... shape) { return one(name, Nd4j.defaultFloatingPointType(), shape); } + /** + * See {@link #one(String, DataType, long...)}. + * Uses the Nd4j default floating point DataType ({@link Nd4j#defaultFloatingPointType()}). + */ public SDVariable one(String name, long... shape) { return one(name, Nd4j.defaultFloatingPointType(), shape); } @@ -3174,11 +3224,18 @@ public class SameDiff extends SDBaseOps { return var(name, new ConstantInitScheme('f', 1.0), dataType, shape); } - + /** + * See {@link #zero(String, DataType, long...)}. + * Uses the Nd4j default floating point DataType ({@link Nd4j#defaultFloatingPointType()}). + */ public SDVariable zero(String name, long... shape) { return zero(name, Nd4j.defaultFloatingPointType(), shape); } + /** + * See {@link #zero(String, DataType, int...)}. + * Uses the Nd4j default floating point DataType ({@link Nd4j#defaultFloatingPointType()}). + */ public SDVariable zero(String name, int... shape) { return zero(name, Nd4j.defaultFloatingPointType(), shape); } @@ -3293,6 +3350,18 @@ public class SameDiff extends SDBaseOps { } //TODO only allowing null datatype for TF import (it's fixed in a later step) - don't want this in the public API! + + /** + * Variable initialization with a specified {@link WeightInitScheme} + * This method creates an SDVariable of the specified {@link VariableType}; a VARIABLE type SDVariable must be floating point, and is a trainable parameter. See {@link VariableType} for more details. + * + * @param name the name of the variable + * @param variableType the SameDiff variable type of the variable (e.g. CONSTANT, PLACEHOLDER, etc.) + * @param weightInitScheme the weight initialization scheme + * @param dataType the data type of the variable (float, int, etc) + * @param shape the shape of the array to be created + * @return the created variable + */ public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme, org.nd4j.linalg.api.buffer.DataType dataType, long... shape) { @@ -3932,7 +4001,7 @@ public class SameDiff extends SDBaseOps { * @param varName the variable name to remove * @param function the function to remove the argument from */ - public void removeArgFromFunction(String varName, DifferentialFunction function) { + public void removeArgFromOp(String varName, DifferentialFunction function) { val args = function.args(); for (int i = 0; i < args.length; i++) { @@ -4324,7 +4393,7 @@ public class SameDiff extends SDBaseOps { } //Update the internal state: outgoing variables for function - if (getOutputsForFunction(function) == null) + if (getOutputsForOp(function) == null) addOutgoingFor(ret, function); return ret; @@ -4357,7 +4426,7 @@ //Update the internal state: outgoing variables for function - if (getOutputsForFunction(function) == null) + if (getOutputsForOp(function) == null) addOutgoingFor(ret, function); return ret; @@ -4428,7 +4497,9 @@ public class SameDiff extends SDBaseOps { .build(); } - + /** + * Create a new TensorArray. 
+ */ public TensorArray tensorArray(DataType dataType) { TensorArray ta = new TensorArray(this, dataType); SDVariable[] outVars = ta.outputVariables(); @@ -4439,7 +4510,6 @@ public class SameDiff extends SDBaseOps { * @param functionName * @param with */ - public SDVariable invokeFunctionOn(String functionName, SameDiff with) { SameDiff instance = sameDiffFunctionInstances.get(functionName); SDVariable ret = instance.invokeGraphOn(with); @@ -5746,6 +5816,13 @@ public class SameDiff extends SDBaseOps { return bufferBuilder.dataBuffer(); } + /** + * See {@link #asFlatGraph(long, ExecutorConfiguration, boolean)}. + * + * Uses the default {@link ExecutorConfiguration} with output mode as + * {@link OutputMode#VARIABLE_SPACE}, execution mode as {@link ExecutionMode#SEQUENTIAL}, + * with profiling disabled and gather timings enabled. + */ public FlatGraph asFlatGraph(boolean includeUpdaterState) { return FlatGraph.getRootAsFlatGraph(this.asFlatBuffers(includeUpdaterState)); } @@ -5765,6 +5842,10 @@ public class SameDiff extends SDBaseOps { * This method exports the current SameDiff instance into FlatBuffers format, returning the array ops and * all arrays as a ByteBuffer containing the FlatBuffers format data * + * Uses the default {@link ExecutorConfiguration} with output mode as + * {@link OutputMode#VARIABLE_SPACE}, execution mode as {@link ExecutionMode#SEQUENTIAL}, + * with profiling disabled and gather timings enabled. + * * @param includeUpdaterState If true: include the updater state (state for updaters such as Adam, Nesterov, AdaGrad etc) * @return a ByteBuffer holding the exported FlatBuffers representation of the graph */ @@ -5870,7 +5951,11 @@ public class SameDiff extends SDBaseOps { /** * This method converts SameDiff instance to FlatBuffers and saves it to file which can be restored later
- * This includes the updater state, if applicable + * This includes the updater state, if applicable. + * + * Uses the default {@link ExecutorConfiguration} with output mode as + * {@link OutputMode#VARIABLE_SPACE}, execution mode as {@link ExecutionMode#SEQUENTIAL}, + * with profiling disabled and gather timings enabled. * * @param file File to save the FlatBuffers serialized graph (including arrays) to */ @@ -5878,6 +5963,13 @@ public class SameDiff extends SDBaseOps { asFlatFile(file, true); } + /** + * See {@link #asFlatFile(File, ExecutorConfiguration, boolean)}. + * + * Uses the default {@link ExecutorConfiguration} with output mode as + * {@link OutputMode#VARIABLE_SPACE}, execution mode as {@link ExecutionMode#SEQUENTIAL}, + * with profiling disabled and gather timings enabled. + */ public void asFlatFile(@NonNull File file, boolean withUpdaterState) throws IOException { val fb = asFlatBuffers(withUpdaterState); val offset = fb.position(); @@ -5943,6 +6035,8 @@ public class SameDiff extends SDBaseOps { * instance from a byte buffers * instance. * + * See {@link #fromFlatBuffers(ByteBuffer, boolean)}. Loads updater state (loadUpdaterState is true). + * * @param bbIn the input byte buffer * @return the created samediff instance * @throws IOException @@ -5951,6 +6045,16 @@ public class SameDiff extends SDBaseOps { return fromFlatBuffers(bbIn, true); } + /** + * Create a {@link SameDiff} + * instance from a byte buffers + * instance. + * + * @param bbIn the input byte buffer + * @param loadUpdaterState If true, load the updater state (Adam etc state). For training, use true. For inference, use false + * @return the created samediff instance + * @throws IOException + */ public static SameDiff fromFlatBuffers(ByteBuffer bbIn, boolean loadUpdaterState) throws IOException { FlatGraph fg = FlatGraph.getRootAsFlatGraph(bbIn); @@ -6287,7 +6391,7 @@ public class SameDiff extends SDBaseOps { public String summary() { Map varMap = variableMap(); - DifferentialFunction[] functions = functions(); + DifferentialFunction[] functions = ops(); int countVarsWithArrays = 0; @@ -6324,7 +6428,7 @@ public class SameDiff extends SDBaseOps { if (outputOf == null) { outputOf = ""; } else { - DifferentialFunction d = getFunctionById(outputOf); + DifferentialFunction d = getOpById(outputOf); outputOf = d.getOwnName() + "(" + d.opName() + ")"; } outputOfFn.put(s, outputOf); @@ -6412,7 +6516,7 @@ public class SameDiff extends SDBaseOps { for (Map.Entry e : sameDiffFunctionInstances.entrySet()) { SameDiff sd = e.getValue(); int vars = sd.variableMap().size(); - int fns = (sd.functions() == null ? 0 : sd.functions().length); + int fns = (sd.ops() == null ? 
0 : sd.ops().length); int defFns = sd.definedFunctionNames().size(); sb.append(String.format(format, e.getKey(), String.valueOf(vars), String.valueOf(fns), String.valueOf(defFns))).append("\n"); @@ -6422,11 +6526,16 @@ public class SameDiff extends SDBaseOps { return sb.toString(); } - + /** + * Calculate data types for the variables in the graph + */ public Map calculateOutputDataTypes() { return calculateOutputDataTypes(false); } + /** + * Calculate data types for the variables in the graph + */ public Map calculateOutputDataTypes(boolean dynamicUpdate) { List allVars = new ArrayList<>(variables.keySet()); DataTypesSession session = new DataTypesSession(this, dynamicUpdate); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java index 8d806249d..387e25f48 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java @@ -325,7 +325,7 @@ public abstract class AbstractSession { } - } else if (sameDiff.getVariableOutputFunction(varToExec.getVariable()) != null) { + } else if (sameDiff.getVariableOutputOp(varToExec.getVariable()) != null) { //Variable is the output of an op -> execute op String opName = sameDiff.getVariables().get(varToExec.getVariable()).getOutputOfOp(); @@ -336,7 +336,7 @@ public abstract class AbstractSession { //Post execution: work out what is now available for exec - String[] opOutputVarNames = sameDiff.getFunctionById(opName).outputVariablesNames(); + String[] opOutputVarNames = sameDiff.getOpById(opName).outputVariablesNames(); Preconditions.checkState(opOutputValues.length == opOutputVarNames.length, "Unexpected number of outputs from executed op %s:" + " got %s outputs when %s outputs were expected (%s)", parameterizedOp.getClass().getSimpleName(), opOutputValues.length, @@ -423,10 +423,10 @@ public abstract class AbstractSession { //Note subgraph initially should include placeholders and constants while (!processingQueue.isEmpty()) { String varName = processingQueue.remove(); - String opName = (sameDiff.getVariableOutputFunction(varName) == null ? null : sameDiff.getVariableOutputFunction(varName).getOwnName()); + String opName = (sameDiff.getVariableOutputOp(varName) == null ? null : sameDiff.getVariableOutputOp(varName).getOwnName()); if (!subgraph.contains(varName)) { - String[] opInputs = opName == null ? null : sameDiff.getInputsForFunction(sameDiff.getFunctionById(opName)); + String[] opInputs = opName == null ? null : sameDiff.getInputsForOp(sameDiff.getOpById(opName)); List controlDeps = sameDiff.getVariables().get(varName).getControlDeps(); int numInputs = (opInputs == null ? 
0 : opInputs.length); if (controlDeps != null) { @@ -457,7 +457,7 @@ public abstract class AbstractSession { if (opName != null) { //To execute op - and hence get this variable: need inputs to that op - String[] inputs = sameDiff.getInputsForFunction(sameDiff.getFunctionById(opName)); + String[] inputs = sameDiff.getInputsForOp(sameDiff.getOpById(opName)); for (String s2 : inputs) { if (!subgraph.contains(s2)) { processingQueue.add(s2); @@ -501,7 +501,7 @@ public abstract class AbstractSession { if (inputForOps != null) { for (String opName : inputForOps) { - DifferentialFunction fn = sameDiff.getFunctionById(opName); + DifferentialFunction fn = sameDiff.getOpById(opName); if (fn instanceof Merge) { //Merge op: available for execution when *any* of its inputs are available. But only mark it for exec once... List opOutputs = sameDiff.getOps().get(opName).getOutputsOfOp(); @@ -888,7 +888,7 @@ public abstract class AbstractSession { //Mark that outVar needs this specific executedVar (i.e., specific frame/iteration) //However, in the case of enter nodes, they are available for ALL iterations (used in loop conditions, for example) Variable v = sameDiff.getVariables().get(inputVar.getVariable()); - boolean isEnter = sameDiff.getVariableOutputFunction(v.getVariable().getVarName()) instanceof Enter; + boolean isEnter = sameDiff.getVariableOutputOp(v.getVariable().getVarName()) instanceof Enter; if(isEnter){ VarId iter0 = forVariable; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DataTypesSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DataTypesSession.java index 6eff336e8..56a6a406e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DataTypesSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DataTypesSession.java @@ -59,7 +59,7 @@ public class DataTypesSession extends AbstractSession inputs, Set allIterInputs, Set constAndPhInputs, Map placeholderValues) { - DifferentialFunction df = sameDiff.getFunctionById(opName); + DifferentialFunction df = sameDiff.getOpById(opName); List inputDataTypes = new ArrayList<>(); for(SDVariable v : df.args()){ DataType dt = v.dataType(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java index 3d40e205a..e16dad580 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java @@ -16,7 +16,6 @@ package org.nd4j.autodiff.samediff.internal; -import com.google.common.collect.ImmutableMap; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.functions.DifferentialFunction; @@ -121,12 +120,12 @@ public class InferenceSession extends AbstractSession 0){ SameDiffOp sdOp = sameDiff.getOps().get(op.getOwnName()); - ImmutableMap.Builder namedOutsBuilder = ImmutableMap.builder(); + Map namedOutsBuilder = new HashMap<>(); for(int i = 0 ; i < out.length ; i++) namedOutsBuilder.put(sdOp.outputsOfOp.get(i), out[i]); - Map namedOuts = namedOutsBuilder.build(); + Map namedOuts = Collections.unmodifiableMap(namedOutsBuilder); for(Listener l : 
listeners){ if(l.isActive(at.operation())) { @@ -223,7 +222,7 @@ public class InferenceSession extends AbstractSession Enter -> TensorArrayRead //TODO also TensorArrayWrite, scatter, etc?? - inTensorArray = sameDiff.getVariableOutputFunction(inTensorArray.getVarName()).arg(); + inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.getVarName()).arg(); v = newVarId(inTensorArray.getVarName(), v.getParentFrame()); } @@ -300,10 +299,10 @@ public class InferenceSession extends AbstractSession Enter -> TensorArrayWrite //TODO also TensorArrayScatter, etc?? - inTensorArray = sameDiff.getVariableOutputFunction(inTensorArray.getVarName()).arg(); + inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.getVarName()).arg(); tArr = newVarId(inTensorArray.getVarName(), tArr.getParentFrame()); } @@ -405,7 +404,7 @@ public class InferenceSession extends AbstractSession opInputs, Set allIterInputs, Set constAndPhInputs, Map placeholderValues) { - DifferentialFunction df = sameDiff.getFunctionById(opName); + DifferentialFunction df = sameDiff.getOpById(opName); //TODO We should clone these ops - probably - as we don't want them shared between threads/sessions! //But let's only clone them *once* and cache in inference session - not on every exec diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java index b23dd576b..1bde174c1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java @@ -18,7 +18,6 @@ package org.nd4j.autodiff.samediff.ops; import com.google.common.collect.Sets; import java.util.HashMap; -import java.util.HashSet; import java.util.Map; import java.util.Set; import lombok.NonNull; @@ -27,7 +26,6 @@ import org.nd4j.autodiff.samediff.ArgumentInterceptor; import org.nd4j.autodiff.samediff.NameScope; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition; import org.nd4j.autodiff.samediff.SameDiffLambda; import org.nd4j.autodiff.samediff.SameDiffNoArgSingleLambda; import org.nd4j.autodiff.samediff.SameDiffSingleLambda; @@ -3377,7 +3375,7 @@ public abstract class SDBaseOps { for(SameDiffOp op : sd().getOpsInScope(ifScope)) { for(String in : op.getInputsToOp()){ - sd().removeArgFromFunction(in, op.getOp()); + sd().removeArgFromOp(in, op.getOp()); } sd().getOps().remove(op.getName()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java index 8a203c989..fab50a937 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java @@ -385,6 +385,29 @@ public class SDCNN extends SDOps { return updateVariableNameAndReference(ret, name); } + /** + * 3D CNN deconvolution operation with or without optional bias + * + * @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + * @param weights Weights array - shape [kD, kH, kW, oC, iC] + * @param bias Bias array - optional, may be null. 
If non-null, must have shape [outputChannels] + * @param config Configuration + */ + public SDVariable deconv3d(SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) { + return deconv3d(null, input, weights, bias, config); + } + + /** + * 3D CNN deconvolution operation with no bias + * + * @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + * @param weights Weights array - shape [kD, kH, kW, oC, iC] + * @param config Configuration + */ + public SDVariable deconv3d(SDVariable input, SDVariable weights, DeConv3DConfig config) { + return deconv3d(input, weights, null, config); + } + /** * Convolution 2d layer batch to space operation on 4d input.
* Reduces input channels dimension by rearranging data into a larger spatial dimensions
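The two deconv3d overloads added above only fill in the optional name and bias arguments. A usage sketch under assumptions: NDHWC input with the shapes from the javadoc, and lombok-style builder fields (kD, kH, kW) on DeConv3DConfig whose exact names are not shown in this diff:

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
import org.nd4j.linalg.factory.Nd4j;

public class Deconv3dExample {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        // Input [bS, iD, iH, iW, iC] (NDHWC), per the javadoc above
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 8, 8, 8, 16);
        // Weights [kD, kH, kW, oC, iC]: 2x2x2 kernel, 32 output channels, 16 input channels
        SDVariable w = sd.var("w", Nd4j.rand(new int[]{2, 2, 2, 32, 16}));
        DeConv3DConfig conf = DeConv3DConfig.builder().kD(2).kH(2).kW(2).build(); // assumed builder fields
        SDVariable out = sd.cnn().deconv3d(in, w, conf); // new no-bias overload
    }
}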
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java index a33a26f67..c69295cc6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java @@ -199,7 +199,7 @@ public class LegacyOpMapper { case 25: return Or.class; case 26: - return OldAtan2Op.class; + throw new UnsupportedOperationException("OldAtan2Op (op number " + opNum + ") is no longer supported."); case 27: return LogicalOr.class; case 28: @@ -243,7 +243,7 @@ public class LegacyOpMapper { case 18: return Floor.class; case 20: - return OldReverse.class; + throw new UnsupportedOperationException("OldReverse (op number " + opNum + ") is no longer supported."); default: throw new UnsupportedOperationException("No known transform same op for op number: " + opNum); } @@ -581,19 +581,19 @@ public class LegacyOpMapper { public static Class<?> pairwiseOpClass(int opNum){ switch (opNum){ case 0: - return OldAddOp.class; + throw new UnsupportedOperationException("OldAddOp (op number " + opNum + ") is no longer supported."); case 1: return CopyOp.class; case 2: - return OldDivOp.class; + throw new UnsupportedOperationException("OldDivOp (op number " + opNum + ") is no longer supported."); case 3: - return OldEqualTo.class; + throw new UnsupportedOperationException("OldEqualTo (op number " + opNum + ") is no longer supported."); case 4: - return OldGreaterThan.class; + throw new UnsupportedOperationException("OldGreaterThan (op number " + opNum + ") is no longer supported."); case 5: - return OldLessThan.class; + throw new UnsupportedOperationException("OldLessThan (op number " + opNum + ") is no longer supported."); case 6: - return OldMulOp.class; + throw new UnsupportedOperationException("OldMulOp (op number " + opNum + ") is no longer supported."); case 7: return Pow.class; case 8: @@ -603,15 +603,15 @@ public class LegacyOpMapper { case 10: return Eps.class; case 11: - return OldGreaterThanOrEqual.class; + throw new UnsupportedOperationException("OldGreaterThanOrEqual (op number " + opNum + ") is no longer supported."); case 12: - return OldLessThanOrEqual.class; + throw new UnsupportedOperationException("OldLessThanOrEqual (op number " + opNum + ") is no longer supported."); case 13: - return OldMax.class; + throw new UnsupportedOperationException("OldMax (op number " + opNum + ") is no longer supported."); case 14: - return OldMin.class; + throw new UnsupportedOperationException("OldMin (op number " + opNum + ") is no longer supported."); case 15: - return OldNotEqualTo.class; + throw new UnsupportedOperationException("OldNotEqualTo (op number " + opNum + ") is no longer supported."); case 16: return Set.class; case 17: @@ -631,11 +631,11 @@ public class LegacyOpMapper { case 59: return RemainderOp.class; case 60: - return OldFModOp.class; + throw new UnsupportedOperationException("OldFModOp (op number " + opNum + ") is no longer supported."); case 69: - return OldAtan2Op.class; + throw new UnsupportedOperationException("OldAtan2Op (op number " + opNum + ") is no longer supported."); case 20: - return OldFloorDivOp.class; + throw new UnsupportedOperationException("OldFloorDivOp (op number " + opNum + ") is no longer supported."); case 26: return RelativeError.class; case 27: diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/GraphTransformUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/GraphTransformUtil.java index 3a736e93e..afe8551f3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/GraphTransformUtil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/GraphTransformUtil.java @@ -78,7 +78,7 @@ public class GraphTransformUtil { if (oldInputsForOps != null) { List newInputsForOps = new ArrayList<>(); for (String s : oldInputsForOps) { - DifferentialFunction df = sd.getFunctionById(s); + DifferentialFunction df = sd.getOpById(s); if (!allSubGraphFns.contains(df)) { newInputsForOps.add(s); } @@ -141,7 +141,7 @@ public class GraphTransformUtil { // (1) variable is (was) input to op that has been removed - just remove from list // (2) variable is now connected directly as an output: (A->B->C) becomes (A->C) // For the latter case, this - DifferentialFunction df = sd.getFunctionById(opName); + DifferentialFunction df = sd.getOpById(opName); if (allSubGraphFns.contains(df)) { newInputsForOp.remove(opName); } @@ -178,7 +178,7 @@ public class GraphTransformUtil { */ public static List getSubgraphsMatching(SameDiff sd, SubGraphPredicate p) { List out = new ArrayList<>(); - for (DifferentialFunction df : sd.functions()) { + for (DifferentialFunction df : sd.ops()) { if (p.matches(sd, df)) { SubGraph sg = p.getSubGraph(sd, df); out.add(sg); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraph.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraph.java index fab19ae28..c9f1f52bf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraph.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraph.java @@ -20,7 +20,6 @@ import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; -import org.apache.commons.lang3.builder.Diff; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -68,7 +67,7 @@ public class SubGraph { boolean allInSubgraph = true; if(inputsFor != null){ for(String opOwnName : inputsFor) { - if (!inSubgraph(sameDiff.getFunctionById(opOwnName))){ + if (!inSubgraph(sameDiff.getOpById(opOwnName))){ allInSubgraph = false; break; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraphPredicate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraphPredicate.java index 7dcd18bd5..5d7e117a2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraphPredicate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/SubGraphPredicate.java @@ -77,7 +77,7 @@ public class SubGraphPredicate extends OpPredicate { } SDVariable in = inputs[inNum]; - DifferentialFunction df = sameDiff.getVariableOutputFunction(in.getVarName()); + DifferentialFunction df = sameDiff.getVariableOutputOp(in.getVarName()); if (df == null || !e.getValue().matches(sameDiff, df)) { return false; } 
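SubGraphPredicate and GraphTransformUtil both rely on the renamed getVariableOutputOp to step from a variable back to the op that produced it, which returns null for placeholders, constants and plain variables. A minimal traversal sketch using only methods that appear in this diff:

import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;

public class ProducerWalk {
    // Print the chain of producing ops, following each op's first input.
    static void printProducerChain(SameDiff sd, String varName) {
        DifferentialFunction fn = sd.getVariableOutputOp(varName);
        while (fn != null) {
            System.out.println(varName + " <- " + fn.opName());
            SDVariable[] fnArgs = fn.args();    // inputs of the producing op
            if (fnArgs.length == 0) break;      // reached a leaf (no inputs)
            varName = fnArgs[0].getVarName();   // follow the first input only
            fn = sd.getVariableOutputOp(varName);
        }
    }
}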
@@ -103,7 +103,7 @@ public class SubGraphPredicate extends OpPredicate { for(Map.Entry entry : opInputSubgraphPredicates.entrySet()){ OpPredicate p2 = entry.getValue(); SDVariable arg = rootFn.arg(entry.getKey()); - DifferentialFunction df = sd.getVariableOutputFunction(arg.getVarName()); + DifferentialFunction df = sd.getVariableOutputOp(arg.getVarName()); if(df != null){ childNodes.add(df); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java index 02c76a3b2..b8625afde 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java @@ -28,7 +28,6 @@ import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.autodiff.validation.listeners.NonInplaceValidationListener; import org.nd4j.base.Preconditions; -import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; @@ -106,7 +105,7 @@ public class GradCheckUtil { } Set fnOutputs = new HashSet<>(); - for(DifferentialFunction f : sd.functions()){ + for(DifferentialFunction f : sd.ops()){ for(SDVariable s : f.outputVariables()){ fnOutputs.add(s.getVarName()); } @@ -593,7 +592,7 @@ public class GradCheckUtil { 4. Gradient function: should contain all of the existing functions, and more */ - DifferentialFunction[] dfs = sd.functions(); + DifferentialFunction[] dfs = sd.ops(); List vars = sd.variables(); Set varSetStr = new HashSet<>(); @@ -661,7 +660,7 @@ public class GradCheckUtil { //Check that all original functions are present in the gradient function for(DifferentialFunction dfOrig : dfs){ - Preconditions.checkNotNull(gradFn.getFunctionById(dfOrig.getOwnName()), "DifferentialFunction " + dfOrig.getOwnName() + Preconditions.checkNotNull(gradFn.getOpById(dfOrig.getOwnName()), "DifferentialFunction " + dfOrig.getOwnName() + " from original SameDiff instance not present in grad fn"); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java index 5ea26b0ab..e9ad61c04 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java @@ -79,7 +79,6 @@ import org.nd4j.linalg.api.ops.impl.transforms.gradient.*; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Not; -import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax; import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.*; import org.nd4j.linalg.api.ops.impl.transforms.strict.GELUDerivative; import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative; @@ -94,7 +93,6 @@ import org.nd4j.linalg.api.ops.random.impl.Linspace; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.function.Function; 
-import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.linalg.primitives.Pair; import org.tensorflow.framework.OpDef; @@ -464,7 +462,7 @@ public class OpValidation { //i.e., don't double count if a SameDiff instance has multiple copies of the same op type //Collect coverage information for backprop: - DifferentialFunction[] functions = sd.functions(); + DifferentialFunction[] functions = sd.ops(); Set backpropSeen = new HashSet<>(); for (DifferentialFunction df : functions) { backpropSeen.add(df.getClass()); @@ -481,7 +479,7 @@ public class OpValidation { if (testCase.fwdTestFns() != null) { for (String s : testCase.fwdTestFns().keySet()) { //Determine the differential function that this variable is the output of, if any - DifferentialFunction df = sd.getVariableOutputFunction(s); + DifferentialFunction df = sd.getVariableOutputOp(s); if (df != null) { if (seen == null) seen = new HashSet<>(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java index 7e4de14c0..63a5a012a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java @@ -31,7 +31,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -708,8 +708,8 @@ public class ROC extends BaseEvaluation { itp = isTruePositive; ifp = isFalsePositive; } else { - isTruePositive = Nd4j.getExecutioner().exec(new OldMulOp(predictedClass1, positiveActualClassColumn, itp)); - isFalsePositive = Nd4j.getExecutioner().exec(new OldMulOp(predictedClass1, negativeActualClassColumn, ifp)); + isTruePositive = Nd4j.getExecutioner().exec(new MulOp(predictedClass1, positiveActualClassColumn, itp))[0]; + isFalsePositive = Nd4j.getExecutioner().exec(new MulOp(predictedClass1, negativeActualClassColumn, ifp))[0]; } //Counts for this batch: diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java index 1231fcb37..00f4964fe 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java @@ -68,17 +68,6 @@ public class DifferentialFunctionClassHolder { add("sameDiff"); add("ownName"); }}; - private static final Set classesWithConfig = new LinkedHashSet(){{ - add(AvgPooling2D.class.getName()); - add(Conv2D.class.getName()); - add(Conv3D.class.getName()); - add(LocalResponseNormalization.class.getName()); - add(MaxPooling2D.class.getName()); - add(Pooling2D.class.getName()); - add(Pooling3D.class.getName()); - 
add(DepthwiseConv2D.class.getName()); - add(DeConv2DTF.class.getName()); - }}; //When determining fields/properties, where should we terminate the search? //We don't want to include every single field from every single superclass private static final Set<String> classesToIgnore = new HashSet<>(Arrays.asList( @@ -165,16 +154,37 @@ public class DifferentialFunctionClassHolder { Map<String, Field> fieldNames = new LinkedHashMap<>(); Class current = df.getClass(); val fields = new ArrayList<Field>(); + boolean isFirst = true; + while (current.getSuperclass() != null && !classesToIgnore.contains(current.getSuperclass())) { - if (classesWithConfig.contains(current.getName())) { - val fieldName = "config"; + if (df.isConfigProperties() && isFirst) { - val configField = current.getDeclaredField(fieldName); - if (configField == null) { - continue; + String fieldName = df.configFieldName(); + + if(fieldName == null) + fieldName = "config"; + + Field configField = null; + try{ + configField = current.getDeclaredField(fieldName); + } catch (NoSuchFieldException e){ + Class currentConfig = current.getSuperclass(); + + // find a config field in superclasses + while(currentConfig.getSuperclass() != null){ + try { + configField = currentConfig.getDeclaredField(fieldName); + break; + } catch (NoSuchFieldException e2){ + currentConfig = currentConfig.getSuperclass(); + } + } } + if(configField == null) + continue; + val configFieldClass = configField.getType(); for (val field : configFieldClass.getDeclaredFields()) { @@ -206,6 +216,7 @@ public class DifferentialFunctionClassHolder { // do something with current's fields current = (Class) current.getSuperclass(); + isFirst = false; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 1c3bd8c89..e13c45294 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -347,14 +347,6 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace.class, org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet.class, org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps.class, - org.nd4j.linalg.api.ops.impl.transforms.comparison.OldEqualTo.class, - org.nd4j.linalg.api.ops.impl.transforms.comparison.OldGreaterThan.class, - org.nd4j.linalg.api.ops.impl.transforms.comparison.OldGreaterThanOrEqual.class, - org.nd4j.linalg.api.ops.impl.transforms.comparison.OldLessThan.class, - org.nd4j.linalg.api.ops.impl.transforms.comparison.OldLessThanOrEqual.class, - org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax.class, - org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin.class, - org.nd4j.linalg.api.ops.impl.transforms.comparison.OldNotEqualTo.class, org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Assign.class, org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace.class, @@ -453,15 +445,6 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp.class, -
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAddOp.class, - org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAtan2Op.class, - org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldDivOp.class, - org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldFModOp.class, - org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldFloorDivOp.class, - org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp.class, - org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldRDivOp.class, - org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldRSubOp.class, - org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldSubOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.PowPairwise.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp.class, @@ -493,8 +476,6 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.same.Max.class, org.nd4j.linalg.api.ops.impl.transforms.same.Min.class, org.nd4j.linalg.api.ops.impl.transforms.same.Negative.class, - org.nd4j.linalg.api.ops.impl.transforms.same.OldIdentity.class, - org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse.class, org.nd4j.linalg.api.ops.impl.transforms.same.OneMinus.class, org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal.class, org.nd4j.linalg.api.ops.impl.transforms.same.Round.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/BaseGraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/BaseGraphMapper.java index 1f8ec828b..92c888e0c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/BaseGraphMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/BaseGraphMapper.java @@ -361,7 +361,7 @@ public abstract class BaseGraphMapper= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP()); addArgs(); - sameDiff.putFunctionForId(this.getOwnName(), this); + sameDiff.putOpForId(this.getOwnName(), this); sameDiff.addArgsFor(inputFunctions, this); } @@ -113,12 +111,6 @@ public class Conv1D extends DynamicCustomOp { return config.toProperties(); } - @Override - public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); - addArgs(); - } - @Override public boolean isConfigProperties() { return true; @@ -129,107 +121,6 @@ public class Conv1D extends DynamicCustomOp { return "config"; } - @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { - OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph); - addArgs(); - } - - - @Override - public Map> attributeAdaptersForFunction() { - Map> ret = new HashMap<>(); - Map tfMappings = new LinkedHashMap<>(); - val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this); - - - tfMappings.put("kH", new ConditionalFieldValueNDArrayShapeAdapter("NHW", 2, 0, fields.get("dataFormat"))); - tfMappings.put("kW", new ConditionalFieldValueNDArrayShapeAdapter("NHW", 3, 1, fields.get("dataFormat"))); - tfMappings.put("sH", new 
ConditionalFieldValueIntIndexArrayAdapter("NHW", 2, 1, fields.get("dataFormat"))); - tfMappings.put("sW", new ConditionalFieldValueIntIndexArrayAdapter("NHW", 3, 2, fields.get("dataFormat"))); - tfMappings.put("isSameMode", new StringEqualsAdapter("SAME")); - tfMappings.put("isNHWC", new StringEqualsAdapter("NHWC")); - - - Map onnxMappings = new HashMap<>(); - onnxMappings.put("kH", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0)); - onnxMappings.put("kW", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0)); - onnxMappings.put("dH", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0)); - onnxMappings.put("dW", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0)); - onnxMappings.put("sH", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0)); - onnxMappings.put("sW", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0)); - onnxMappings.put("isSameMode", new StringEqualsAdapter("SAME")); - onnxMappings.put("isNHWC", new StringEqualsAdapter("NHC")); - - ret.put(tensorflowName(), tfMappings); - ret.put(onnxName(), onnxMappings); - return ret; - } - - @Override - public Map> mappingsForFunction() { - Map> ret = new HashMap<>(); - Map map = new HashMap<>(); - val strideMapping = PropertyMapping.builder() - .tfAttrName("strides") - .onnxAttrName("strides") - .propertyNames(new String[]{"s"}) - .build(); - - val kernelMapping = PropertyMapping.builder() - .propertyNames(new String[]{"k"}) - .tfInputPosition(1) - .shapePosition(0) - .onnxAttrName("kernel_shape") - .build(); - - val paddingMapping = PropertyMapping.builder() - .onnxAttrName("padding") - .propertyNames(new String[]{"p"}) - .build(); - - val dataFormat = PropertyMapping.builder() - .onnxAttrName("data_format") - .tfAttrName("data_format") - .propertyNames(new String[]{"dataFormat"}) - .build(); - - val nhwc = PropertyMapping.builder() - .onnxAttrName("data_format") - .tfAttrName("data_format") - .propertyNames(new String[]{"isNHWC"}) - .build(); - - val sameMode = PropertyMapping.builder() - .onnxAttrName("auto_pad") - .propertyNames(new String[]{"isSameMode"}) - .tfAttrName("padding") - .build(); - - map.put("s", strideMapping); - map.put("k", kernelMapping); - map.put("p", paddingMapping); - map.put("isSameMode", sameMode); - map.put("dataFormat", dataFormat); - map.put("isNHWC", nhwc); - - try { - ret.put(onnxName(), map); - } catch (NoOpNameFoundException e) { - //ignore - } - - - try { - ret.put(tensorflowName(), map); - } catch (NoOpNameFoundException e) { - //ignore - } - - return ret; - } - - @Override public String opName() { return "conv1d"; @@ -241,16 +132,6 @@ public class Conv1D extends DynamicCustomOp { throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); } - @Override - public String tensorflowName() { - return "Conv1D"; - } - - @Override - public String[] tensorflowNames() { - return new String[]{"Conv1D"}; - } - @Override public List calculateOutputDataTypes(List inputDataTypes){ int n = args().length; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java index c32d8a34d..4335f4561 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java @@ -70,7 +70,7 @@ public class Conv2D extends DynamicCustomOp { 
config.getSH(), config.getPH(), config.getDW()); addArgs(); if(sameDiff != null) { - sameDiff.putFunctionForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point + sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point sameDiff.addArgsFor(inputFunctions, this); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java index 1c9b891f3..6cba853d0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java @@ -68,7 +68,7 @@ public class DeConv2D extends DynamicCustomOp { } addArgs(); - sameDiff.putFunctionForId(this.getOwnName(), this); + sameDiff.putOpForId(this.getOwnName(), this); sameDiff.addArgsFor(inputs, this); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java index 34b69054d..085f48365 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java @@ -21,7 +21,6 @@ import lombok.Getter; import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import lombok.val; -import onnx.OnnxProto3; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -71,7 +70,7 @@ public class DeConv2DTF extends DynamicCustomOp { } addArgs(); - sameDiff.putFunctionForId(this.getOwnName(), this); + sameDiff.putOpForId(this.getOwnName(), this); sameDiff.addArgsFor(inputs, this); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java index dcad3b8b4..0ea84e081 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java @@ -62,7 +62,7 @@ public class DepthwiseConv2D extends DynamicCustomOp { this.sameDiff = sameDiff; this.config = config; addArgs(); - sameDiff.putFunctionForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point + sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point sameDiff.addArgsFor(inputFunctions, this); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java index e71e05f77..de4e763bc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java @@ -76,6 +76,16 @@ public class LocalResponseNormalization extends DynamicCustomOp { addIArgument(config.getDepth()); } + @Override + public boolean isConfigProperties() { + return true; + } + + @Override + public String configFieldName(){ + return "config"; + } + @Override public String opName() { return "lrn"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java index 586a3b8ce..62e373832 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java @@ -65,7 +65,7 @@ public class TensorMmul extends DynamicCustomOp { this.sameDiff = sameDiff; this.mMulTranspose = mMulTranspose; this.axes = dimensions; - if(!addedEdges && sameDiff.getOutputsForFunction(this) == null) { + if(!addedEdges && sameDiff.getOutputsForOp(this) == null) { addedEdges = true; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java index b1c4b34ad..e0b0450d3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java @@ -151,7 +151,7 @@ public class Concat extends DynamicCustomOp { removeInputArgument(inputArgs[inputArguments().length - 1]); } - sameDiff.removeArgFromFunction(input,this); + sameDiff.removeArgFromOp(input,this); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayConcat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayConcat.java index 60fd6b94e..07bdab586 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayConcat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayConcat.java @@ -72,7 +72,7 @@ public class TensorArrayConcat extends BaseTensorOp { public List calculateOutputDataTypes(java.util.List inputDataType){ //Same output type as the TensorArray - which is defined by input 0 SDVariable tArr = arg(0); - TensorArray t3 = (TensorArray) sameDiff.getVariableOutputFunction(tArr.getVarName()); + TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.getVarName()); org.nd4j.linalg.api.buffer.DataType dt = t3.getTensorArrayDataType(); return Collections.singletonList(dt); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayGather.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayGather.java index 3ce983079..3ab0d91c9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayGather.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayGather.java @@ -72,7 +72,7 @@ public class TensorArrayGather extends BaseTensorOp { public List calculateOutputDataTypes(java.util.List inputDataType){ //Same output type as the TensorArray - which is defined by input 0 SDVariable tArr = arg(0); - TensorArray t3 = (TensorArray) sameDiff.getVariableOutputFunction(tArr.getVarName()); + TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.getVarName()); org.nd4j.linalg.api.buffer.DataType dt = t3.getTensorArrayDataType(); return Collections.singletonList(dt); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayRead.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayRead.java index cbf51e523..619216813 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayRead.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayRead.java @@ -72,7 +72,7 @@ public class TensorArrayRead extends BaseTensorOp { dt = importDataType; } else { SDVariable tArr = arg(0); - DifferentialFunction op = sameDiff.getVariableOutputFunction(tArr.getVarName()); + DifferentialFunction op = sameDiff.getVariableOutputOp(tArr.getVarName()); TensorArray t3 = (TensorArray) op; dt = t3.getTensorArrayDataType(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java index 15f044320..f940ac7a0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDIndex; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -55,6 +56,14 @@ public class Pad extends DynamicCustomOp { addTArgument(padValue); } + public Pad(@NonNull INDArray in, @NonNull INDArray padding, INDArray out, @NonNull Mode mode, double padValue){ + super(null, new INDArray[]{in, padding}, out == null ? 
null : new INDArray[]{out}); + Preconditions.checkState(padding.dataType().isIntType(), "Padding array must be an integer datatype, got %s", padding.dataType()); + this.mode = mode; + addIArgument(mode.ordinal()); + addTArgument(padValue); + } + @Override public String opName(){ return "pad"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldEqualTo.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldEqualTo.java deleted file mode 100644 index 94e38bcae..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldEqualTo.java +++ /dev/null @@ -1,95 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.comparison; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformBoolOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; - -import java.util.Arrays; -import java.util.List; - -/** - * Bit mask over the ndarrays as to whether - * the components are equal or not - * - * @author Adam Gibson - */ -public class OldEqualTo extends BaseTransformBoolOp { - - - - public OldEqualTo(SameDiff sameDiff) { - super(sameDiff); - } - - public OldEqualTo(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, Object[] extraArgs) { - super(sameDiff, i_v1, i_v2, extraArgs); - } - - public OldEqualTo(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); - } - - public OldEqualTo(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, i_v1, i_v2); - } - - public OldEqualTo(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, i_v1, i_v2, inPlace); - } - - public OldEqualTo() {} - - public OldEqualTo(INDArray x, INDArray y, INDArray z) { - super(x, y, z); - } - - public OldEqualTo(INDArray x, INDArray y) { - super(x, y, null); - } - - @Override - public int opNum() { - return 0; - } - - @Override - public String opName() { - return "oldeq"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No Tensorflow op opName found for " + opName()); - } - - @Override - public List doDiff(List f1) { - //Equals op: 2 inputs, not continuously differentiable but 0s almost everywhere - return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1])); - } - -} diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldGreaterThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldGreaterThan.java deleted file mode 100644 index e558ecfcc..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldGreaterThan.java +++ /dev/null @@ -1,91 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.comparison; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformBoolOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; - -import java.util.Arrays; -import java.util.List; - -/** - * Bit mask over the ndarrays as to whether - * the components are greater than or not - * - * @author Adam Gibson - */ -public class OldGreaterThan extends BaseTransformBoolOp { - public OldGreaterThan(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, i_v1, i_v2); - } - - public OldGreaterThan(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, i_v1, i_v2, inPlace); - } - - public OldGreaterThan(SameDiff sameDiff) { - super(sameDiff); - } - - public OldGreaterThan(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); - } - - public OldGreaterThan() {} - - public OldGreaterThan(INDArray x, INDArray z) { - super(x, z); - } - - public OldGreaterThan(INDArray x, INDArray y, INDArray z) { - super(x, y, z); - } - - public OldGreaterThan(INDArray x) { - super(x); - } - - @Override - public int opNum() { - return 1; - } - - @Override - public String opName() { - return "oldgt"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow name found"); - } - - - @Override - public List doDiff(List f1) { - return Arrays.asList(outputVariables()[0]); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldGreaterThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldGreaterThanOrEqual.java deleted file mode 100644 index 4932209e6..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldGreaterThanOrEqual.java +++ /dev/null @@ -1,85 +0,0 @@ 
-/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.comparison; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformBoolOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; - -import java.util.List; - -/** - * Bit mask over the ndarrays as to whether - * the components are greater than or equal or not - * - * @author Adam Gibson - */ -public class OldGreaterThanOrEqual extends BaseTransformBoolOp { - public OldGreaterThanOrEqual(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, i_v1, i_v2); - } - - public OldGreaterThanOrEqual(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, i_v1, i_v2, inPlace); - } - - public OldGreaterThanOrEqual(SameDiff sameDiff) { - super(sameDiff); - } - - public OldGreaterThanOrEqual(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); - } - - public OldGreaterThanOrEqual() {} - - public OldGreaterThanOrEqual(INDArray x, INDArray y, INDArray z) { - super(x, y, z); - } - - public OldGreaterThanOrEqual(INDArray x) { - super(x); - } - - @Override - public int opNum() { - return 4; - } - - @Override - public String opName() { - return "oldgte"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - @Override - public List doDiff(List f1) { - return null; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldLessThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldLessThan.java deleted file mode 100644 index 3e9ae2384..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldLessThan.java +++ /dev/null @@ -1,91 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.comparison; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformBoolOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; - -import java.util.Arrays; -import java.util.List; - -/** - * Bit mask over the ndarrays as to whether - * the components are less than or not - * - * @author Adam Gibson - */ -public class OldLessThan extends BaseTransformBoolOp { - public OldLessThan(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, i_v1, i_v2); - } - - public OldLessThan(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, i_v1, i_v2, inPlace); - } - - public OldLessThan(SameDiff sameDiff) { - super(sameDiff); - } - - public OldLessThan(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); - } - - public OldLessThan() {} - - public OldLessThan(INDArray x, INDArray y, INDArray z) { - super(x, y, z); - } - - public OldLessThan(INDArray x) { - super(x); - } - - public OldLessThan(INDArray x, INDArray z) { - super(x, z); - } - - @Override - public int opNum() { - return 2; - } - - @Override - public String opName() { - return "oldlt"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tf opName found for " + opName()); - } - - - @Override - public List doDiff(List f1) { - return Arrays.asList(outputVariables()[0]); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldLessThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldLessThanOrEqual.java deleted file mode 100644 index 3c51d4293..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldLessThanOrEqual.java +++ /dev/null @@ -1,104 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.comparison; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformBoolOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; - -import java.util.Arrays; -import java.util.List; - -/** - * Bit mask over the ndarrays as to whether - * the components are less than or equal or not - * - * @author Adam Gibson - */ -public class OldLessThanOrEqual extends BaseTransformBoolOp { - public OldLessThanOrEqual(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, i_v1, i_v2); - } - - public OldLessThanOrEqual(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, i_v1, i_v2, inPlace); - } - - public OldLessThanOrEqual(SameDiff sameDiff) { - super(sameDiff); - } - - public OldLessThanOrEqual(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, Object[] extraArgs) { - super(sameDiff, i_v1, i_v2, extraArgs); - } - - public OldLessThanOrEqual(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); - } - - public OldLessThanOrEqual(SameDiff sameDiff, SDVariable i_v, long[] shape, boolean inPlace, Object[] extraArgs) { - super(sameDiff, i_v, shape, inPlace, extraArgs); - } - - public OldLessThanOrEqual(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs) { - super(sameDiff, i_v, extraArgs); - } - - public OldLessThanOrEqual() {} - - public OldLessThanOrEqual(INDArray x, INDArray z) { - super(x, z); - } - - public OldLessThanOrEqual(INDArray x, INDArray y, INDArray z) { - super(x, y, z); - } - - public OldLessThanOrEqual(INDArray x) { - super(x); - } - - @Override - public int opNum() { - return 5; - } - - @Override - public String opName() { - return "oldlte"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - - - @Override - public List doDiff(List f1) { - return Arrays.asList(outputVariables()[0]); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldMax.java deleted file mode 100644 index 33f922ec9..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldMax.java +++ /dev/null @@ -1,88 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.comparison; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformOp; -import org.nd4j.linalg.api.ops.BaseTransformSameOp; - -import java.util.List; - -/** - * Max function - * - * @author Adam Gibson - */ -public class OldMax extends BaseTransformSameOp { - public OldMax(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, i_v1, i_v2); - } - - public OldMax(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, i_v1, i_v2, inPlace); - } - - public OldMax(SameDiff sameDiff) { - super(sameDiff); - } - - public OldMax(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); - } - - public OldMax() {} - - public OldMax(INDArray x, INDArray y, INDArray z) { - super(x, y, z); - } - - public OldMax(INDArray x) { - super(x); - } - - public OldMax(INDArray ndArray, INDArray dup) { - super(ndArray, dup); - } - - @Override - public int opNum() { - return 7; - } - - @Override - public String opName() { - return "old_max_transform"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("This is not meant to be mapped, use Max instead"); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("This is not meant to be mapped, use Max instead"); - } - - @Override - public List doDiff(List f1) { - return null; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldMin.java deleted file mode 100644 index 5e8e749d0..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldMin.java +++ /dev/null @@ -1,88 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.comparison; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformOp; -import org.nd4j.linalg.api.ops.BaseTransformSameOp; - -import java.util.List; - -/** - * Min function - * - * @author Adam Gibson - */ -public class OldMin extends BaseTransformSameOp { - public OldMin(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, i_v1, i_v2); - } - - public OldMin(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, i_v1, i_v2, inPlace); - } - - public OldMin(SameDiff sameDiff) { - super(sameDiff); - } - - public OldMin(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); - } - - public OldMin() {} - - public OldMin(INDArray x, INDArray y, INDArray z) { - super(x, y, z); - } - - public OldMin(INDArray x) { - super(x); - } - - public OldMin(INDArray x, INDArray z) { - super(x, z); - } - - @Override - public int opNum() { - return 8; - } - - @Override - public String opName() { - return "old_min_transform"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("This is not meant to be mapped, use Max instead"); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("This is not meant to be mapped, use Max instead"); - } - - @Override - public List doDiff(List f1) { - return null; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldNotEqualTo.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldNotEqualTo.java deleted file mode 100644 index 0c8aa3d82..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/OldNotEqualTo.java +++ /dev/null @@ -1,87 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.comparison; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformBoolOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; - -import java.util.Arrays; -import java.util.List; - -/** - * Not equal to function: - * Bit mask over whether 2 elements are not equal or not - * - * @author Adam Gibson - */ -public class OldNotEqualTo extends BaseTransformBoolOp { - public OldNotEqualTo(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, i_v1, i_v2); - } - - public OldNotEqualTo(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, i_v1, i_v2, inPlace); - } - - public OldNotEqualTo() { - } - - public OldNotEqualTo(INDArray x) { - super(x); - } - - public OldNotEqualTo(INDArray x, INDArray y, INDArray z) { - super(x, y, z); - } - - public OldNotEqualTo(INDArray x, INDArray z) { - super(x, z); - } - - @Override - public int opNum() { - return 6; - } - - @Override - public String opName() { - return "old_neq"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No op name found"); - } - - - @Override - public List doDiff(List i_v) { - - return Arrays.asList(f().neg(i_v.get(0))); - } - - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java index 8b25052be..8a782acf6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java @@ -22,11 +22,13 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; import java.util.Arrays; import java.util.Collections; import java.util.List; +import org.nd4j.linalg.ops.transforms.Transforms; /** * Arc Tangent elementwise function @@ -39,6 +41,15 @@ public class ATan2 extends BaseDynamicTransformOp { super(sameDiff, new SDVariable[] {y, x} ,false); } + /** + * Note that the order of x and y match {@link java.lang.Math#atan2(double, double)}, + * and are reversed when compared to OldATan2. 
+ * See {@link Transforms#atan2(org.nd4j.linalg.api.ndarray.INDArray, org.nd4j.linalg.api.ndarray.INDArray)} + */ + public ATan2(INDArray x, INDArray y, INDArray z) { + super(new INDArray[]{x, y}, new INDArray[]{ z }); + } + public ATan2() {} @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java index 762bba65d..b786602d3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java @@ -46,6 +46,9 @@ public class EqualTo extends BaseDynamicTransformOp { super(inputs, outputs); } + public EqualTo(INDArray x, INDArray y, INDArray z){ + this(new INDArray[]{x, y}, new INDArray[]{z}); + } @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java index 64d3f2544..27b2ea189 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java @@ -47,7 +47,9 @@ public class GreaterThan extends BaseDynamicTransformOp { super(inputs, outputs); } - + public GreaterThan(INDArray x, INDArray y, INDArray z){ + this(new INDArray[]{x, y}, new INDArray[]{z}); + } @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java index de5661313..48d3953aa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java @@ -46,6 +46,10 @@ public class GreaterThanOrEqual extends BaseDynamicTransformOp { super(inputs, outputs); } + public GreaterThanOrEqual(INDArray x, INDArray y, INDArray z){ + this(new INDArray[]{x, y}, new INDArray[]{z}); + } + @Override public int opNum() { return 11; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java index 4e26249a6..0ee59458c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java @@ -47,6 +47,10 @@ public class LessThan extends BaseDynamicTransformOp { super(inputs, outputs); } + public LessThan(INDArray x, INDArray y, INDArray z){ + this(new INDArray[]{x, y}, new INDArray[]{z}); + } + @Override public String opName() { return "less"; diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java index b7ef72193..56a5882db 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java @@ -45,6 +45,11 @@ public class LessThanOrEqual extends BaseDynamicTransformOp { public LessThanOrEqual( INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } + + public LessThanOrEqual(INDArray x, INDArray y, INDArray z){ + this(new INDArray[]{x, y}, new INDArray[]{z}); + } + @Override public String opName() { return "less_equal"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java index 407c60e35..b1f0dadbb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java @@ -46,6 +46,9 @@ public class NotEqualTo extends BaseDynamicTransformOp { super(inputs, outputs); } + public NotEqualTo(INDArray x, INDArray y, INDArray z){ + this(new INDArray[]{x, y}, new INDArray[]{z}); + } @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java index b59f67bef..b4aa329a6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -38,6 +39,27 @@ public class Reverse extends DynamicCustomOp { public Reverse() { } + /** + * Inplace reverse. See {@link #Reverse(INDArray, INDArray)} + */ + public Reverse(INDArray x){ + this(x, x); + this.inPlace = true; + } + + /** + * Reverses whole array for compatibility with OldReverse. + * + * Note that otherwise, passing null or empty dimensions will result in a noop. 
+ */ + public Reverse(INDArray x, INDArray z){ + super(new INDArray[]{x}, new INDArray[]{z}); + this.dimensions = new int[x.rank()]; + for(int i = 0 ; i < this.dimensions.length ; i++) + this.dimensions[i] = i; + addIArgument(dimensions); + } + @Override public String opName() { return "reverse"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FModOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FModOp.java index 9823c0566..c2314cdc4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FModOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FModOp.java @@ -52,6 +52,9 @@ public class FModOp extends BaseTransformSameOp { public FModOp(INDArray x, INDArray z) { super(x, z); } + public FModOp(INDArray x, INDArray y, INDArray z) { + super(x, y, z); + } public FModOp(INDArray x) { super(x); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldAddOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldAddOp.java deleted file mode 100644 index 18b4e5912..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldAddOp.java +++ /dev/null @@ -1,89 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformAnyOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; -import org.nd4j.linalg.api.ops.BaseTransformSameOp; - -import java.util.ArrayList; -import java.util.List; - -/** - * @deprecated Use {@link AddOp} - */ -@Deprecated -public class OldAddOp extends BaseTransformAnyOp { - public OldAddOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, i_v1, i_v2); - } - - public OldAddOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, i_v1, i_v2, inPlace); - } - - public OldAddOp() {} - - public OldAddOp(INDArray x) { - super(x); - } - - public OldAddOp(INDArray x, INDArray z) { - super(x, z); - } - - public OldAddOp(INDArray x, INDArray y, INDArray z) { - super(x, y, z); - } - - @Override - public int opNum() { - return 0; - } - - @Override - public String opName() { - return "old_add"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - @Override - public List doDiff(List i_v) { - SDVariable gradWrtX = f().div(i_v.get(0),rarg()); - SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(larg(),rarg())); - - List ret = new ArrayList<>(2); - ret.add(gradWrtX); - ret.add(gradWrtY); - return ret; - } - - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldAtan2Op.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldAtan2Op.java deleted file mode 100644 index 5a3907fba..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldAtan2Op.java +++ /dev/null @@ -1,86 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformAnyOp; -import org.nd4j.linalg.api.ops.BaseTransformFloatOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; -import org.nd4j.linalg.api.ops.BaseTransformSameOp; - -import java.util.List; - -/** - * atan2 operation - * - * @author raver119@gmail.com - */ -public class OldAtan2Op extends BaseTransformAnyOp { - public OldAtan2Op(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, i_v1, i_v2); - } - - public OldAtan2Op(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, i_v1, i_v2, inPlace); - } - - public OldAtan2Op(SameDiff sameDiff) { - super(sameDiff); - } - - public OldAtan2Op(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); - } - - public OldAtan2Op() {} - - public OldAtan2Op(INDArray x, INDArray y) { - super(x, y, x); - } - - public OldAtan2Op(INDArray x, INDArray y, INDArray z) { - super(x, y, z); - } - - @Override - public int opNum() { - return 16; - } - - @Override - public String opName() { - return "old_atan2"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx opName found for " + opName()); - } - - @Override - public String tensorflowName() { - return "ATan2"; - } - - @Override - public List doDiff(List f1) { - return null; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldDivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldDivOp.java deleted file mode 100644 index 16f2a7761..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldDivOp.java +++ /dev/null @@ -1,88 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformAnyOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; -import org.nd4j.linalg.api.ops.BaseTransformSameOp; - -import java.util.ArrayList; -import java.util.List; - -/** - * @deprecated Use {@link DivOp} - */ -@Deprecated -public class OldDivOp extends BaseTransformAnyOp { - public OldDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, i_v1, i_v2); - } - - public OldDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, i_v1, i_v2, inPlace); - } - - public OldDivOp() {} - - public OldDivOp(INDArray x, INDArray y, INDArray z) { - super(x, y, z); - } - - public OldDivOp(INDArray x) { - super(x); - } - - public OldDivOp(INDArray x, INDArray z) { - super(x, z); - } - - @Override - public int opNum() { - return 2; - } - - @Override - public String opName() { - return "olddiv"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - @Override - public List doDiff(List i_v) { - SDVariable gradWrtX = f().div(i_v.get(0),rarg()); - SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(larg(),rarg())); - List ret = new ArrayList<>(2); - ret.add(gradWrtX); - ret.add(gradWrtY); - return ret; - } - - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldFModOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldFModOp.java deleted file mode 100644 index 26f6664de..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldFModOp.java +++ /dev/null @@ -1,88 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformAnyOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; -import org.nd4j.linalg.api.ops.BaseTransformSameOp; - -import java.util.List; - -/** - * Floating point remainder - * - * @author raver119@gmail.com - */ -public class OldFModOp extends BaseTransformAnyOp { - public OldFModOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, i_v1, i_v2); - } - - public OldFModOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, i_v1, i_v2, inPlace); - } - - public OldFModOp(SameDiff sameDiff) { - super(sameDiff); - } - - public OldFModOp(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); - } - - public OldFModOp() {} - - public OldFModOp(INDArray x) { - super(x); - } - - public OldFModOp(INDArray x, INDArray z) { - super(x, z); - } - public OldFModOp(INDArray x, INDArray y, INDArray z) { - super(x, y, z); - } - - @Override - public int opNum() { - return 15; - } - - @Override - public String opName() { - return "oldfmod"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - @Override - public List doDiff(List f1) { - return null; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldFloorDivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldFloorDivOp.java deleted file mode 100644 index a0c95c15d..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldFloorDivOp.java +++ /dev/null @@ -1,89 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformAnyOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; -import org.nd4j.linalg.api.ops.BaseTransformSameOp; - -import java.util.ArrayList; -import java.util.List; - -/** - * Truncated division operation - * - * @author Adam Gibson - */ -public class OldFloorDivOp extends BaseTransformAnyOp { - public OldFloorDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, i_v1, i_v2); - } - - public OldFloorDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, i_v1, i_v2, inPlace); - } - - public OldFloorDivOp() {} - - public OldFloorDivOp(INDArray x) { - super(x); - } - - public OldFloorDivOp(INDArray x, INDArray z) { - super(x, z); - } - - public OldFloorDivOp(INDArray x, INDArray y, INDArray z) { - super(x, y, z); - } - - @Override - public int opNum() { - return 18; - } - - @Override - public String opName() { - return "oldfloordiv"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - @Override - public List doDiff(List i_v) { - SDVariable gradWrtX = f().div(i_v.get(0),rarg()); - SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(larg(),rarg())); - List ret = new ArrayList<>(2); - ret.add(gradWrtX); - ret.add(gradWrtY); - return ret; - } - - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldMulOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldMulOp.java deleted file mode 100644 index d1e8b6bc6..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldMulOp.java +++ /dev/null @@ -1,91 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformAnyOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; -import org.nd4j.linalg.api.ops.BaseTransformSameOp; - -import java.util.ArrayList; -import java.util.List; - -/** - * @deprecated Use {@link MulOp} - */ -@Deprecated -public class OldMulOp extends BaseTransformAnyOp { - public OldMulOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, i_v1, i_v2); - } - - public OldMulOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, i_v1, i_v2, inPlace); - } - - public OldMulOp() {} - - public OldMulOp(INDArray x) { - super(x); - } - - public OldMulOp(INDArray x, INDArray z) { - super(x, z); - } - - public OldMulOp(INDArray x, INDArray y, INDArray z) { - super(x, y, z); - } - - @Override - public int opNum() { - return 3; - } - - @Override - public String opName() { - return "oldmul"; - } - - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - - @Override - public List doDiff(List i_v) { - SDVariable g = sameDiff.setupFunction(i_v.get(0)); - SDVariable gradWrtX = f().mul(g,rarg()); - SDVariable gradWrtY = f().mul(g,larg()); - List ret = new ArrayList<>(2); - ret.add(gradWrtX); - ret.add(gradWrtY); - return ret; - } - - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldRDivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldRDivOp.java deleted file mode 100644 index 9838cfd37..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldRDivOp.java +++ /dev/null @@ -1,87 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformAnyOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; -import org.nd4j.linalg.api.ops.BaseTransformSameOp; - -import java.util.ArrayList; -import java.util.List; - -/** - * @deprecated Use {@link RDivOp} - */ -@Deprecated -public class OldRDivOp extends BaseTransformAnyOp { - public OldRDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, i_v1, i_v2); - } - - public OldRDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, i_v1, i_v2, inPlace); - } - - public OldRDivOp() {} - - public OldRDivOp(INDArray x) { - super(x); - } - - public OldRDivOp(INDArray x, INDArray z) { - super(x, z); - } - - public OldRDivOp(INDArray x, INDArray y, INDArray z) { - super(x, y, z); - } - - @Override - public int opNum() { - return 11; - } - - @Override - public String opName() { - return "oldrdiv"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - @Override - public List doDiff(List i_v) { - SDVariable gradWrtX = f().div(i_v.get(0),larg()); - SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(rarg(),larg())); - List ret = new ArrayList<>(2); - ret.add(gradWrtX); - ret.add(gradWrtY); - return ret; - } - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldRSubOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldRSubOp.java deleted file mode 100644 index 91d2cd90b..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldRSubOp.java +++ /dev/null @@ -1,87 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformAnyOp; - -import java.util.ArrayList; -import java.util.List; - -/** - * @deprecated Use {@link RSubOp} - */ -@Deprecated -public class OldRSubOp extends BaseTransformAnyOp { - public OldRSubOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, i_v1, i_v2); - } - - public OldRSubOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, i_v1, i_v2, inPlace); - } - - public OldRSubOp() {} - - public OldRSubOp(INDArray x) { - super(x); - } - - public OldRSubOp(INDArray x, INDArray z) { - super(x, z); - } - - public OldRSubOp(INDArray x, INDArray y, INDArray z) { - super(x, y, z); - } - - @Override - public int opNum() { - return 5; - } - - @Override - public String opName() { - return "old_rsub"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - @Override - public List doDiff(List i_v) { - SDVariable gradWrtX = f().div(i_v.get(0),rarg()); - SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(larg(),rarg())); - - List ret = new ArrayList<>(2); - ret.add(gradWrtX); - ret.add(gradWrtY); - return ret; - } - - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldSubOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldSubOp.java deleted file mode 100644 index 1f806c34a..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/OldSubOp.java +++ /dev/null @@ -1,89 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformAnyOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; -import org.nd4j.linalg.api.ops.BaseTransformSameOp; - -import java.util.ArrayList; -import java.util.List; - -/** - * @deprecated Use {@link SubOp} - */ -@Deprecated -public class OldSubOp extends BaseTransformAnyOp { - public OldSubOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, i_v1, i_v2); - } - - public OldSubOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, i_v1, i_v2, inPlace); - } - - public OldSubOp() {} - - public OldSubOp(INDArray x) { - super(x); - } - - public OldSubOp(INDArray x, INDArray z) { - super(x, z); - } - - public OldSubOp(INDArray x, INDArray y, INDArray z) { - super(x, y, z); - } - - @Override - public int opNum() { - return 6; - } - - @Override - public String opName() { - return "old_sub"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - @Override - public List doDiff(List i_v) { - SDVariable gradWrtX = f().div(i_v.get(0),rarg()); - SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(larg(),rarg())); - - List ret = new ArrayList<>(2); - ret.add(gradWrtX); - ret.add(gradWrtY); - return ret; - } - - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java index 31af5b0c2..3e548226f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; @@ -40,6 +41,10 @@ public class Identity extends BaseDynamicTransformOp { super(sd, new SDVariable[]{input}, false); } + public Identity(INDArray x, INDArray z){ + super(new INDArray[]{x}, new INDArray[]{z}); + } + public Identity(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/OldIdentity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/OldIdentity.java deleted file mode 100644 index 936c829c8..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/OldIdentity.java +++ /dev/null @@ -1,77 +0,0 @@ -/******************************************************************************* - * 
Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.same; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformOp; -import org.nd4j.linalg.api.ops.BaseTransformSameOp; - -import java.util.Arrays; -import java.util.List; -import java.util.UUID; - -/** - * Identity function - * - * @author Adam Gibson - */ -public class OldIdentity extends BaseTransformSameOp { - public OldIdentity(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); - } - - public OldIdentity() { - } - - public OldIdentity(INDArray x, INDArray z) { - super(x, z); - } - - public OldIdentity(INDArray x) { - super(x); - } - - @Override - public int opNum() { - return 15; - } - - @Override - public String opName() { - return "old_identity"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("This op does not work with onnx."); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("This op does not work with tensorflow."); - } - - - @Override - public List doDiff(List i_v) { - return i_v; - } - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/OldReverse.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/OldReverse.java deleted file mode 100644 index 493e3c92d..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/OldReverse.java +++ /dev/null @@ -1,74 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.same; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformSameOp; - -import java.util.Arrays; -import java.util.List; - -/** - * OldReverse op - */ -public class OldReverse extends BaseTransformSameOp { - public OldReverse(SameDiff sameDiff, SDVariable i_v, int... dimensions) { - super(sameDiff, i_v, false); - this.dimensions = dimensions; - } - - public OldReverse() { - } - - public OldReverse(INDArray x, INDArray z) { - super(x, z); - } - - public OldReverse(INDArray x) { - super(x); - } - - @Override - public int opNum() { - return 20; - } - - @Override - public String opName() { - return "old_reverse"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - - @Override - public List doDiff(List f1) { - SDVariable ret = f().reverse(f1.get(0), dimensions); - return Arrays.asList(ret); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/OldConvolution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/OldConvolution.java index 35ee39644..6e79ca067 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/OldConvolution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/OldConvolution.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.convolution; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.Pad.Mode; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -129,8 +130,7 @@ public class OldConvolution { long w = img.size(3); long outHeight = outSize(h, kh, sy, ph, coverAll); long outWidth = outSize(w, kw, sx, pw, coverAll); - INDArray padded = Nd4j.pad(img, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}}, - Nd4j.PadMode.CONSTANT); + INDArray padded = Nd4j.pad(img, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}}, Mode.CONSTANT, pval); INDArray ret = Nd4j.create(n, c, kh, kw, outHeight, outWidth); for (int i = 0; i < kh; i++) { //offset for the row based on the stride and output height diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/env/impl/OmpNumThreadsAction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/env/impl/OmpNumThreadsAction.java index 5ece067dd..fff0ddd84 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/env/impl/OmpNumThreadsAction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/env/impl/OmpNumThreadsAction.java @@ -35,7 +35,8 @@ public class OmpNumThreadsAction implements EnvironmentalAction { val skipper = System.getenv(ND4JEnvironmentVars.ND4J_SKIP_BLAS_THREADS); if (skipper == null) { // we infer num threads only if skipper undefined - Nd4j.setNumThreads(v); + // Nd4j.setNumThreads(v); + // method does not do anything anymore and 
was removed } } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Broadcast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Broadcast.java index 2a4673242..41d9e74c6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Broadcast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Broadcast.java @@ -20,6 +20,14 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.broadcast.*; import org.nd4j.linalg.api.ops.impl.broadcast.bool.*; +import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo; +import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan; +import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual; +import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan; +import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual; +import org.nd4j.linalg.api.ops.impl.transforms.custom.Max; +import org.nd4j.linalg.api.ops.impl.transforms.custom.Min; +import org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*; import org.nd4j.linalg.api.ops.impl.transforms.comparison.*; import org.nd4j.linalg.api.ops.impl.transforms.same.AMax; @@ -42,7 +50,7 @@ public class Broadcast { public static INDArray add(INDArray x, INDArray y, INDArray z, int... dimensions) { if(dimensions == null || dimensions.length == 0) { validateShapesNoDimCase(x,y,z); - return Nd4j.getExecutioner().exec(new OldAddOp(x,y,z)); + return Nd4j.getExecutioner().exec(new AddOp(x,y,z))[0]; } return Nd4j.getExecutioner().exec(new BroadcastAddOp(x,y,z,dimensions)); @@ -66,7 +74,7 @@ public class Broadcast { public static INDArray div(INDArray x, INDArray y, INDArray z, int... dimensions) { if(dimensions == null || dimensions.length == 0) { validateShapesNoDimCase(x,y,z); - return Nd4j.getExecutioner().exec(new OldDivOp(x,y,z)); + return Nd4j.getExecutioner().exec(new DivOp(x,y,z))[0]; } return Nd4j.getExecutioner().exec(new BroadcastDivOp(x,y,z,dimensions)); @@ -78,7 +86,7 @@ public class Broadcast { public static INDArray eq(INDArray x, INDArray y, INDArray z, int... dimensions) { if(dimensions == null || dimensions.length == 0) { validateShapesNoDimCase(x,y,z); - return Nd4j.getExecutioner().exec(new OldEqualTo(x,y,z)); + return Nd4j.getExecutioner().exec(new EqualTo(x,y,z))[0]; } return Nd4j.getExecutioner().exec(new BroadcastEqualTo(x,y,z,dimensions)); } @@ -89,7 +97,7 @@ public class Broadcast { public static INDArray gt(INDArray x, INDArray y, INDArray z, int... dimensions) { if(dimensions == null || dimensions.length == 0) { validateShapesNoDimCase(x,y,z); - return Nd4j.getExecutioner().exec(new OldGreaterThan(x,y,z)); + return Nd4j.getExecutioner().exec(new GreaterThan(x,y,z))[0]; } return Nd4j.getExecutioner().exec(new BroadcastGreaterThan(x,y,z,dimensions)); @@ -101,7 +109,7 @@ public class Broadcast { public static INDArray gte(INDArray x, INDArray y, INDArray z, int... 
dimensions) { if(dimensions == null || dimensions.length == 0) { validateShapesNoDimCase(x,y,z); - return Nd4j.getExecutioner().exec(new OldGreaterThanOrEqual(x,y,z)); + return Nd4j.getExecutioner().exec(new GreaterThanOrEqual(x,y,z))[0]; } return Nd4j.getExecutioner().exec(new BroadcastGreaterThanOrEqual(x,y,z,dimensions)); @@ -113,7 +121,7 @@ public class Broadcast { public static INDArray lt(INDArray x, INDArray y, INDArray z, int... dimensions) { if(dimensions == null || dimensions.length == 0) { validateShapesNoDimCase(x,y,z); - return Nd4j.getExecutioner().exec(new OldLessThan(x,y,z)); + return Nd4j.getExecutioner().exec(new LessThan(x,y,z))[0]; } return Nd4j.getExecutioner().exec(new BroadcastLessThan(x,y,z,dimensions)); @@ -125,7 +133,7 @@ public class Broadcast { public static INDArray lte(INDArray x, INDArray y, INDArray z, int... dimensions) { if(dimensions == null || dimensions.length == 0) { validateShapesNoDimCase(x,y,z); - return Nd4j.getExecutioner().exec(new OldLessThanOrEqual(x,y,z)); + return Nd4j.getExecutioner().exec(new LessThanOrEqual(x,y,z))[0]; } return Nd4j.getExecutioner().exec(new BroadcastLessThanOrEqual(x,y,z,dimensions)); @@ -137,7 +145,7 @@ public class Broadcast { public static INDArray mul(INDArray x, INDArray y, INDArray z, int... dimensions) { if(dimensions == null || dimensions.length == 0) { validateShapesNoDimCase(x,y,z); - return Nd4j.getExecutioner().exec(new OldMulOp(x,y,z)); + return Nd4j.getExecutioner().exec(new MulOp(x,y,z))[0]; } return Nd4j.getExecutioner().exec(new BroadcastMulOp(x,y,z,dimensions)); @@ -149,7 +157,7 @@ public class Broadcast { public static INDArray neq(INDArray x, INDArray y, INDArray z, int... dimensions) { if(dimensions == null || dimensions.length == 0) { validateShapesNoDimCase(x,y,z); - return Nd4j.getExecutioner().exec(new OldNotEqualTo(x,y,z)); + return Nd4j.getExecutioner().exec(new NotEqualTo(x,y,z))[0]; } return Nd4j.getExecutioner().exec(new BroadcastNotEqual(x,y,z,dimensions)); @@ -161,7 +169,7 @@ public class Broadcast { public static INDArray rdiv(INDArray x, INDArray y, INDArray z, int... dimensions) { if(dimensions == null || dimensions.length == 0) { validateShapesNoDimCase(x,y,z); - return Nd4j.getExecutioner().exec(new OldRDivOp(x,y,z)); + return Nd4j.getExecutioner().exec(new RDivOp(x,y,z))[0]; } return Nd4j.getExecutioner().exec(new BroadcastRDivOp(x,y,z,dimensions)); @@ -173,7 +181,7 @@ public class Broadcast { public static INDArray rsub(INDArray x, INDArray y, INDArray z, int... dimensions) { if(dimensions == null || dimensions.length == 0) { validateShapesNoDimCase(x,y,z); - return Nd4j.getExecutioner().exec(new OldSubOp(x,y,z)); + return Nd4j.getExecutioner().exec(new SubOp(x,y,z))[0]; } return Nd4j.getExecutioner().exec(new BroadcastRSubOp(x,y,z,dimensions)); @@ -185,7 +193,7 @@ public class Broadcast { public static INDArray sub(INDArray x, INDArray y, INDArray z, int... dimensions) { if(dimensions == null || dimensions.length == 0) { validateShapesNoDimCase(x,y,z); - return Nd4j.getExecutioner().exec(new OldSubOp(x,y,z)); + return Nd4j.getExecutioner().exec(new SubOp(x,y,z))[0]; } return Nd4j.getExecutioner().exec(new BroadcastSubOp(x,y,z,dimensions)); @@ -197,7 +205,7 @@ public class Broadcast { public static INDArray max(INDArray x, INDArray y, INDArray z, int... 
dimensions) { if(dimensions == null || dimensions.length == 0) { validateShapesNoDimCase(x,y,z); - return Nd4j.getExecutioner().exec(new OldMax(x,y,z)); + return Nd4j.getExecutioner().exec(new Max(x,y,z))[0]; } @@ -210,7 +218,7 @@ public class Broadcast { public static INDArray min(INDArray x, INDArray y, INDArray z, int... dimensions) { if(dimensions == null || dimensions.length == 0) { validateShapesNoDimCase(x,y,z); - return Nd4j.getExecutioner().exec(new OldMin(x,y,z)); + return Nd4j.getExecutioner().exec(new Min(x,y,z))[0]; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 9ad0cda08..683acc0d6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -57,8 +57,10 @@ import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; import org.nd4j.linalg.api.ops.impl.shape.Diag; import org.nd4j.linalg.api.ops.impl.shape.DiagPart; import org.nd4j.linalg.api.ops.impl.shape.Stack; +import org.nd4j.linalg.api.ops.impl.transforms.Pad; +import org.nd4j.linalg.api.ops.impl.transforms.Pad.Mode; +import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse; import org.nd4j.linalg.api.ops.impl.shape.Tile; -import org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse; import org.nd4j.linalg.api.ops.random.custom.RandomExponential; import org.nd4j.linalg.api.ops.random.impl.*; import org.nd4j.linalg.api.rng.DefaultRandom; @@ -182,93 +184,69 @@ public class Nd4j { nd4j.initContext(); } - public enum PadMode { - CONSTANT, EDGE, LINEAR_RAMP, MAXIMUM, MEAN, MEDIAN, MINIMUM, REFLECT, SYMMETRIC, WRAP - - } - /** - * See {@link #pad(INDArray, int[][], List, PadMode)} with zero padding. (zeros for constantValues). + * See {@link #pad(INDArray, INDArray)}. Uses 0 padding. */ - public static INDArray pad(INDArray toPad, int[][] padWidth, PadMode padMode) { - return pad(toPad, padWidth, ArrayUtil.zerosMatrix(toPad.shape()), padMode); + public static INDArray pad(@NonNull INDArray toPad, @NonNull int[][] padWidth){ + return pad(toPad, Nd4j.createFromArray(padWidth)); } /** - * Pad the given ndarray to the size along each dimension + * See {@link #pad(INDArray, INDArray)}. Uses 0 padding, and uses padWidth for all dimensions. + */ + public static INDArray pad(@NonNull INDArray toPad, @NonNull int... padWidth){ + return pad(toPad, padWidth, Mode.CONSTANT, 0); + } + + /** + * See {@link #pad(INDArray, INDArray, Pad.Mode, double)} with zero padding (zeros for padValue). + */ + public static INDArray pad(INDArray toPad, INDArray padding) { + return pad(toPad, padding, Mode.CONSTANT, 0); + } + + /** + * See {@link #pad(INDArray, INDArray, Mode, double)}. + */ + public static INDArray pad(@NonNull INDArray toPad, @NonNull int[][] padWidth, @NonNull Pad.Mode padMode, double padValue){ + return pad(toPad, Nd4j.createFromArray(padWidth), padMode, padValue); + } + + /** + * See {@link #pad(INDArray, INDArray, Mode, double)}, uses padWidth for all dimensions. + */ + public static INDArray pad(@NonNull INDArray toPad, @NonNull int[] padWidth, @NonNull Pad.Mode padMode, double padValue){ + int[][] pads = new int[toPad.rank()][padWidth.length]; + for(int i = 0 ; i < pads.length ; i++){ + pads[i] = padWidth; + } + return pad(toPad, pads, padMode, padValue); + } + + /** + * Pad the given ndarray to the size along each dimension. 
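+ * <p>
+ * A minimal usage sketch (illustrative values; one {before, after} pair per dimension):
+ * <pre>
+ * INDArray in = Nd4j.ones(2, 2);
+ * INDArray padWidth = Nd4j.createFromArray(new int[][]{{1, 1}, {2, 2}});
+ * INDArray padded = Nd4j.pad(in, padWidth, Pad.Mode.CONSTANT, 0.0); // result shape: [4, 6]
+ * </pre>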
+ * * @param toPad the ndarray to pad * @param padWidth the width to pad along each dimension - * @param constantValues the values to append for each dimension * @param padMode the mode to pad in + * @param padValue the value used during padding. Only used when padMode is {@link Pad.Mode#CONSTANT}. * @return the padded ndarray * based on the specified mode */ - public static INDArray pad(INDArray toPad, int[][] padWidth, List constantValues, PadMode padMode) { - if (padMode == PadMode.CONSTANT) { - if (padWidth.length < toPad.rank()) - throw new IllegalArgumentException("Please specify a pad width for each dimension"); + public static INDArray pad(@NonNull INDArray toPad, @NonNull INDArray padWidth, @NonNull Pad.Mode padMode, double padValue) { - List sizes = new ArrayList<>(); - for (int i = 0; i < toPad.rank(); i++) { - sizes.add(padWidth[i]); - } + Preconditions.checkArgument(toPad.rank() == padWidth.size(0), + "Must provide padding values for each dimension. Expected %s pairs for a rank %s array, got %s", + toPad.rank(), toPad.rank(), padWidth.size(0)); - return padImpl(toPad, sizes, constantValues); + long[] newShape = new long[toPad.rank()]; + for(int i = 0 ; i < newShape.length ; i++){ + newShape[i] = toPad.size(i) + padWidth.getRow(i).sumNumber().intValue(); } - throw new UnsupportedOperationException(); - } + INDArray out = Nd4j.createUninitialized(toPad.dataType(), newShape); + Pad op = new Pad(toPad, padWidth, out, padMode, padValue); - /** - * See {@link #pad(INDArray, int[][], List, PadMode)} with a 1D int[] for padWidth. - */ - public static INDArray pad(INDArray toPad, int[] padWidth, List constantValues, PadMode padMode) { - if (padMode == PadMode.CONSTANT) { - if (padWidth.length < toPad.rank()) - throw new IllegalArgumentException("Please specify a pad width for each dimension"); - - toPad = Nd4j.stripOnes(toPad); - - List sizes = new ArrayList<>(); - for (int i = 0; i < toPad.rank(); i++) { - sizes.add(padWidth); - } - - return padImpl(toPad, sizes, constantValues); - } - throw new UnsupportedOperationException(); - } - - // common code for pad(INDArray, int[], List, PadMode) and - // pad(INDArray, int[][], List, PadMode) - private static INDArray padImpl(INDArray toPad, List sizes, List constantValues){ - - INDArray ret = toPad; - for (int i = 0; i < toPad.rank(); i++) { - int[] pad = sizes.get(i); - double[] constant = constantValues.get(i); - int padBefore = pad[0]; - int padAfter = pad[1]; - if (constant.length < 2) { - double val = constant[0]; - constant = new double[2]; - constant[0] = val; - constant[1] = val; - } - - double beforeVal = constant[0]; - double afterVal = constant[1]; - ret = Nd4j.prepend(ret, padBefore, beforeVal, i); - ret = Nd4j.append(ret, padAfter, afterVal, i); - - } - return ret; - } - - /** - * See {@link #pad(INDArray, int[][], List, PadMode)} with a 1D int[] for padWidth and zero padding. - */ - public static INDArray pad(INDArray toPad, int[] padWidth, PadMode padMode) { - return pad(toPad, padWidth, ArrayUtil.zerosMatrix(padWidth), padMode); + return Nd4j.getExecutioner().exec(op)[0]; } /** @@ -2639,7 +2617,7 @@ public class Nd4j { * @return the reversed matrix */ public static INDArray reverse(INDArray reverse) { - return Nd4j.getExecutioner().exec(new OldReverse(reverse)); + return Nd4j.getExecutioner().exec(new Reverse(reverse))[0]; } /** @@ -5961,28 +5939,7 @@ public class Nd4j { } } - - /** - * This method returns maximal allowed number of threads for Nd4j. 
- * If value wasn't set in advance, max(1, availableProcessor) will be returned - * @return maximal allowed number of threads - */ - public static int numThreads() { - val v = numThreads.get(); - if (v <= 0) - return Math.max(1, Runtime.getRuntime().availableProcessors() / 2); - else - return v; - } - - /** - * This method sets maximal allowed number of threads for Nd4j - * @param numthreads maximal allowed number of threads - */ - public static void setNumThreads(int numthreads) { - numThreads.set(numthreads); - } - + public static DataType defaultFloatingPointType() { return defaultFloatingPointDataType.get(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java index 99bcbab0a..c398dad72 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java @@ -20,7 +20,7 @@ import lombok.Data; import lombok.NonNull; import org.apache.commons.math3.util.FastMath; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax; +import org.nd4j.linalg.api.ops.impl.transforms.custom.Max; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -104,7 +104,7 @@ public class AdaMaxUpdater implements GradientUpdater { //u = max(B_2 * u, |grad|) u.muli(config.getBeta2()); Transforms.abs(gradient, false); //In-place should be OK here, original gradient values aren't used again later - Nd4j.getExecutioner().exec(new OldMax(u, gradient, u)); + Nd4j.getExecutioner().exec(new Max(u, gradient, u)); double beta1t = FastMath.pow(config.getBeta1(), iteration + 1); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java index 364f0bec3..64a9a6f87 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java @@ -19,7 +19,7 @@ package org.nd4j.linalg.learning; import lombok.Data; import lombok.NonNull; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAddOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Nesterovs; @@ -105,6 +105,6 @@ public class NesterovsUpdater implements GradientUpdater { INDArray ret = vPrev.muli(momentum).addi(v.mul(-momentum - 1)); gradient.assign(ret); */ - Nd4j.getExecutioner().exec(new OldAddOp(vPrev.muli(momentum), v.mul(-momentum - 1), gradient)); + Nd4j.getExecutioner().exec(new AddOp(vPrev.muli(momentum), v.mul(-momentum - 1), gradient)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java index de997ec82..1f4004cb2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java @@ -30,6 +30,10 @@ import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarNot; import org.nd4j.linalg.api.ops.impl.shape.Cross; import org.nd4j.linalg.api.ops.impl.transforms.bool.BooleanNot; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; +import org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2; +import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual; +import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual; +import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.floating.*; import org.nd4j.linalg.api.ops.impl.transforms.comparison.*; @@ -37,7 +41,6 @@ import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAtan2Op; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.PowPairwise; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or; @@ -104,7 +107,7 @@ public class Transforms { public static INDArray reverse(INDArray x, boolean dup) { - return Nd4j.getExecutioner().exec(new OldReverse(x, dup ? x.ulike() : x)); + return Nd4j.getExecutioner().exec(new Reverse(x, dup ? x.ulike() : x))[0]; } /** @@ -140,14 +143,15 @@ public class Transforms { /** * Atan2 operation, new INDArray instance will be returned - * Note the order of x and y parameters is opposite to that of java.lang.Math.atan2 + * Note the order of x and y parameters is opposite to that of {@link java.lang.Math#atan2(double, double)} * * @param x the abscissa coordinate * @param y the ordinate coordinate * @return the theta from point (r, theta) when converting (x,y) from to cartesian to polar coordinates */ public static INDArray atan2(@NonNull INDArray x, @NonNull INDArray y) { - return Nd4j.getExecutioner().exec(new OldAtan2Op(x, y, x.ulike())); + // Switched on purpose, to match OldATan2 (which the javadoc was written for) + return Nd4j.getExecutioner().exec(new ATan2(y, x, x.ulike()))[0]; } /** @@ -789,7 +793,7 @@ public class Transforms { * @return */ public static INDArray lessThanOrEqual(INDArray first, INDArray ndArray, boolean dup) { - return exec(new OldLessThanOrEqual(first, ndArray, Nd4j.createUninitialized(DataType.BOOL, first.shape(), first.ordering()))); + return Nd4j.getExecutioner().exec(new LessThanOrEqual(first, ndArray, Nd4j.createUninitialized(DataType.BOOL, first.shape(), first.ordering())))[0]; } @@ -801,7 +805,7 @@ public class Transforms { * @return */ public static INDArray greaterThanOrEqual(INDArray first, INDArray ndArray, boolean dup) { - return exec(new OldGreaterThanOrEqual(first, ndArray, Nd4j.createUninitialized(DataType.BOOL, first.shape(), first.ordering()))); + return Nd4j.getExecutioner().exec(new GreaterThanOrEqual(first, ndArray, Nd4j.createUninitialized(DataType.BOOL, first.shape(), first.ordering())))[0]; } @@ -986,7 +990,7 @@ public class Transforms { * @return */ public static INDArray identity(INDArray ndArray, boolean dup) { - return exec(dup ? 
new OldIdentity(ndArray, ndArray.ulike()) : new OldIdentity(ndArray)); + return Nd4j.getExecutioner().exec(dup ? new Identity(ndArray, ndArray.ulike()) : new Identity(ndArray, ndArray))[0]; } public static INDArray isMax(INDArray input, DataType dataType) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index c8cbff5f9..55b1a35e9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -16,6 +16,13 @@ package org.nd4j.autodiff.opvalidation; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Ignore; @@ -32,21 +39,20 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; -import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm; +import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.ops.transforms.Transforms; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; - @Slf4j public class LayerOpValidation extends BaseOpValidation { public LayerOpValidation(Nd4jBackend backend) { @@ -311,7 +317,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable loss = sd.mean("loss", out); log.info("Starting test: " + msg); - TestCase tc = new TestCase(sd); + TestCase tc = new TestCase(sd).gradientCheck(true); String error = OpValidation.validate(tc); if (error != null) { failed.add(msg); @@ -344,7 +350,7 @@ public class LayerOpValidation extends BaseOpValidation { String msg = Arrays.toString(inSizeNCHW); - TestCase tc = new TestCase(sd).testName(msg); + TestCase tc = new TestCase(sd).gradientCheck(true).testName(msg); String error = OpValidation.validate(tc); if (error != null) { failed.add(msg); @@ -552,7 +558,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable loss = sd.standardDeviation("loss", out, true); log.info("Starting test: " + msg); - TestCase tc = new TestCase(sd); + TestCase tc = new 
@@ -552,7 +558,7 @@ public class LayerOpValidation extends BaseOpValidation {
             SDVariable loss = sd.standardDeviation("loss", out, true);
 
             log.info("Starting test: " + msg);
-            TestCase tc = new TestCase(sd);
+            TestCase tc = new TestCase(sd).gradientCheck(true);
             tc.testName(msg);
             String error = OpValidation.validate(tc);
             if (error != null) {
@@ -660,7 +666,7 @@ public class LayerOpValidation extends BaseOpValidation {
 //        System.out.println(sd.getFunction("grad").summary());
 
         //Gradient check:
-        TestCase tc = new TestCase(sd);
+        TestCase tc = new TestCase(sd).gradientCheck(true);
         String err = OpValidation.validate(tc);
         assertNull(err);
     }
@@ -705,7 +711,7 @@ public class LayerOpValidation extends BaseOpValidation {
         SDVariable loss = out.std(true);
 
         //Gradient check:
-        TestCase tc = new TestCase(sd);
+        TestCase tc = new TestCase(sd).gradientCheck(true);
         String err = OpValidation.validate(tc);
         assertNull(err);
     }
@@ -798,7 +804,7 @@ public class LayerOpValidation extends BaseOpValidation {
             exp.putScalar(next, max);
         }
 
-        assertNull(OpValidation.validate(new TestCase(sd)
+        assertNull(OpValidation.validate(new TestCase(sd).gradientCheck(true)
                 .expected(outPool, exp)));
     }
 
@@ -856,7 +862,7 @@ public class LayerOpValidation extends BaseOpValidation {
         }
 
         assertNull(OpValidation.validate(new TestCase(sd)
-                .expected(outPool, exp)));
+                .expected(outPool, exp).gradientCheck(true)));
    }
 
@@ -887,16 +893,12 @@ public class LayerOpValidation extends BaseOpValidation {
                 .build();
 
         SDVariable out = sd.cnn().avgPooling3d(in, pooling3DConfig);
-        out = sd.nn().tanh("out", out);
+        out = sd.nn().tanh("loss", out).shape().rename("out");
 
-        INDArray outArr = sd.execAndEndResult();
-        val outShape = outArr.shape();
         // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1;
-        assertArrayEquals(new long[]{mb, nIn, 4, 4, 4}, outShape);
+        INDArray outArr = Nd4j.createFromArray(mb, nIn, 4, 4, 4L);
 
-        SDVariable loss = out.std(true);
-        //Gradient check:
-        TestCase tc = new TestCase(sd);
+        TestCase tc = new TestCase(sd).expectedOutput("out", outArr).gradientCheck(true);
         String err = OpValidation.validate(tc);
         assertNull(err);
     }
 
@@ -927,12 +929,16 @@ public class LayerOpValidation extends BaseOpValidation {
                 .build();
 
         SDVariable out = sd.cnn().maxPooling3d(in, pooling3DConfig);
-        out = sd.nn().tanh("out", out);
+        out = sd.nn().tanh("loss", out).shape().rename("out");
+
+        sd.setLossVariables("loss");
 
-        INDArray outArr = sd.execAndEndResult();
-        val outShape = outArr.shape();
         // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1;
-        assertArrayEquals(new long[]{mb, nIn, 27, 27, 27}, outShape);
+        INDArray outArr = Nd4j.createFromArray(mb, nIn, 27, 27, 27L);
+
+        TestCase tc = new TestCase(sd).expectedOutput("out", outArr).gradientCheck(true);
+        String err = OpValidation.validate(tc);
+        assertNull(err);
     }
 
@@ -958,13 +964,58 @@ public class LayerOpValidation extends BaseOpValidation {
                 .build();
 
         SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig);
-        out = sd.nn().tanh("out", out);
+        out = sd.nn().tanh("loss", out).shape().rename("out");
+
+        sd.setLossVariables("loss");
 
-        INDArray outArr = sd.execAndEndResult();
-        INDArray iOut = out.getArr();
         //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27
-        val outShape = outArr.shape();
-        assertArrayEquals(new long[]{mb, nOut, 27}, outShape);
+        INDArray outArr = Nd4j.createFromArray(mb, nOut, 27L);
+
+        TestCase tc = new TestCase(sd).expectedOutput("out", outArr).gradientCheck(false);
+        String err = OpValidation
+                .validate(tc);
+        assertNull(err);
+    }
+
+    @Test
+    public void testConv1dForward(){
+        int nIn = 2;
+        int nOut = 1;
+        int kernel = 3;
+        int batchSize = 10;
+        int sequenceSize = 5;
+
+        SameDiff sd = SameDiff.create();
+        INDArray inArr = Nd4j.linspace(0, nIn * batchSize * sequenceSize, nIn * batchSize * sequenceSize)
+                .reshape(batchSize, nIn, sequenceSize);
+
+        INDArray wArr = Nd4j.linspace(0, kernel * nIn * nOut, kernel * nIn * nOut)
+                .reshape(kernel, nIn, nOut);
+
+        SDVariable in = sd.var("in", inArr);
+        SDVariable w = sd.var("w", wArr);
+
+        SDVariable res = sd.cnn().conv1d(in, w, Conv1DConfig.builder().k(kernel).build());
+
+        INDArray expected = Nd4j.createFromArray(
+                new double[][][]{
+                        {{82.42424f, 100.60606f, 118.78788f}},
+                        {{264.2424f, 282.4242f, 300.6061f}},
+                        {{446.0606f, 464.2424f, 482.424f}},
+                        {{627.8788f, 646.0606f, 664.2424f}},
+                        {{809.6970f, 827.8788f, 846.0606f}},
+                        {{991.5152f, 1009.69696f, 1027.8788f}},
+                        {{1173.3333f, 1191.5152f, 1209.6970f}},
+                        {{1355.1515f, 1373.3333f, 1391.5153f}},
+                        {{1536.9697f, 1555.1515f, 1573.3333f}},
+                        {{1718.7878f, 1736.9697f, 1755.1515f}}
+                }
+        );
+
+        TestCase tc = new TestCase(sd).gradientCheck(false).expectedOutput(res.getVarName(), expected);
+        String err = OpValidation.validate(tc);
+
+        assertNull(err);
     }
 
@@ -1000,17 +1051,61 @@ public class LayerOpValidation extends BaseOpValidation {
                 .build();
 
         SDVariable out = sd.cnn().conv3d(in, w, b, conv3DConfig);
-        out = sd.nn().tanh("out", out);
+        out = sd.nn().tanh("loss", out).shape().rename("out");
+
+        sd.setLossVariables("loss");
 
-        INDArray outArr = sd.execAndEndResult();
         //Expected output size, NOT same mode: out = (in - k)/d + 1 = (28-2+0)/1+1 = 27
         //Expected output size, WITH same mode: out = in/stride
-        val outShape = outArr.shape();
-        assertArrayEquals(new long[]{mb, nOut, 5, 5, 5}, outShape);
+        INDArray outArr = Nd4j.createFromArray(mb, nOut, 5, 5, 5L);
 
-        SDVariable loss = out.std(true);
-        //Gradient check:
-        TestCase tc = new TestCase(sd);
+        TestCase tc = new TestCase(sd).expectedOutput("out", outArr).gradientCheck(true);
+        String err = OpValidation
+                .validate(tc);
+        assertNull(err);
+    }
+
+    @Test
+    public void testDeConv3dBasic() {
+        int nIn = 4;
+        int nOut = 3;
+        int kH = 2;
+        int kW = 2;
+        int kD = 2;
+
+        int mb = 3;
+        int imgH = 5;
+        int imgW = 5;
+        int imgT = 5;
+
+        SameDiff sd = SameDiff.create();
+        INDArray inArr = Nd4j.rand(new long[]{mb, nIn, 5, 5, 5});
+        INDArray wArr = Nd4j.rand(kD, kH, kW, nOut, nIn);
+
+        SDVariable in = sd.var("in", inArr);
+        SDVariable w = sd.var("W", wArr);
+
+        DeConv3DConfig conv3DConfig = DeConv3DConfig.builder()
+                .kH(kH).kW(kW).kD(kD)
+                .sD(1).sH(1).sW(1)
+                .dH(1).dW(1).dD(1)
+                .isSameMode(true)
+                .dataFormat(DeConv3DConfig.NCDHW)
+                .build();
+
+        SDVariable out = sd.cnn().deconv3d(in, w, conv3DConfig);
+        out = sd.nn().tanh("loss", out).shape().rename("out");
+
+        sd.setLossVariables("loss");
+
+        //Expected conv3d size, NOT same mode: out = (in - k)/d + 1 = (28-2+0)/1+1 = 27
+        //Expected conv3d size, WITH same mode: out = in/stride
+        // reversed this for deconv3d
+        INDArray outArr = Nd4j.createFromArray(new long[]{mb, nOut, imgT, imgH, imgW});
+
+        TestCase tc = new TestCase(sd)
+                .expectedOutput("out", outArr)
+                .gradientCheck(true);
         String err = OpValidation.validate(tc);
         assertNull(err);
     }
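
The rewritten pooling and convolution tests above all use the same shape-validation idiom: rather than executing the graph via sd.execAndEndResult() and asserting on the result's shape, they keep the tanh activation as the loss variable, append a shape() op renamed to "out", and let OpValidation compare that shape vector with the expectation. Condensed from the tests above (names as in the diff):

    SDVariable out = sd.cnn().conv3d(in, w, b, conv3DConfig);
    out = sd.nn().tanh("loss", out).shape().rename("out");   // "loss" is the activation, "out" is its shape
    sd.setLossVariables("loss");

    INDArray expShape = Nd4j.createFromArray(mb, nOut, 5, 5, 5L);   // expected shape, as a long vector
    TestCase tc = new TestCase(sd).expectedOutput("out", expShape).gradientCheck(true);
    assertNull(OpValidation.validate(tc));
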
@@ -1181,23 +1276,23 @@ public class LayerOpValidation extends BaseOpValidation {
         List<String> failed = new ArrayList<>();
 
         for (boolean ncdhw : new boolean[]{true, false}) {
-                int nIn = inSizeNCDHW[1];
-                int[] shape = (ncdhw ? inSizeNCDHW : ncdhwToNdhwc(inSizeNCDHW));
+            int nIn = inSizeNCDHW[1];
+            int[] shape = (ncdhw ? inSizeNCDHW : ncdhwToNdhwc(inSizeNCDHW));
 
-                SameDiff sd = SameDiff.create();
-                SDVariable in = sd.var("in", shape);
+            SameDiff sd = SameDiff.create();
+            SDVariable in = sd.var("in", shape);
 
-                SDVariable out;
-                String msg = "0 - conv3d+bias+same, ncdhw=" + ncdhw + " - input " + Arrays.toString(shape);
+            SDVariable out;
+            String msg = "0 - conv3d+bias+same, ncdhw=" + ncdhw + " - input " + Arrays.toString(shape);
 
-                SDVariable w0 = sd.var("w0", Nd4j.rand(new int[]{2, 2, 2, nIn, 3}).muli(10));  //[kD, kH, kW, iC, oC]
-                SDVariable b0 = sd.var("b0", Nd4j.rand(new long[]{3}).muli(10));
-                out = sd.cnn().conv3d(in, w0, b0, Conv3DConfig.builder()
-                        .dataFormat(ncdhw ? Conv3DConfig.NCDHW : Conv3DConfig.NDHWC)
-                        .isSameMode(true)
-                        .kH(2).kW(2).kD(2)
-                        .sD(1).sH(1).sW(-1).dW(-1)
-                        .build());
+            SDVariable w0 = sd.var("w0", Nd4j.rand(new int[]{2, 2, 2, nIn, 3}).muli(10));  //[kD, kH, kW, iC, oC]
+            SDVariable b0 = sd.var("b0", Nd4j.rand(new long[]{3}).muli(10));
+            out = sd.cnn().conv3d(in, w0, b0, Conv3DConfig.builder()
+                    .dataFormat(ncdhw ? Conv3DConfig.NCDHW : Conv3DConfig.NDHWC)
+                    .isSameMode(true)
+                    .kH(2).kW(2).kD(2)
+                    .sD(1).sH(1).sW(-1).dW(-1)
+                    .build());
         }
     }
 
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java
index f96f7c8bf..c266aa4bd 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java
@@ -38,10 +38,10 @@ import org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod;
 import org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication;
 import org.nd4j.linalg.api.ops.impl.shape.Cross;
 import org.nd4j.linalg.api.ops.impl.transforms.Pad;
-import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax;
-import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Min;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
 import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
@@ -1008,7 +1008,7 @@ public class TransformOpValidation extends BaseOpValidation {
             }
 
-            DifferentialFunction[] funcs = sd.functions();
+            DifferentialFunction[] funcs = sd.ops();
             String name = opName == null ? funcs[0].opName() : opName;
 
@@ -1141,11 +1141,11 @@ public class TransformOpValidation extends BaseOpValidation {
                     break;
                 case 14:
                     t = sd.max(in1, in2);
-                    tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new OldMax(ia, ib, ia.dup())));
+                    tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Max(ia, ib, ia.dup()))[0]);
                     break;
                 case 15:
                     t = sd.min(in1, in2);
-                    tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new OldMin(ia, ib, ia.dup())));
+                    tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Min(ia, ib, ia.dup()))[0]);
                     break;
                 case 16:
                     ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));
@@ -1199,7 +1199,7 @@ public class TransformOpValidation extends BaseOpValidation {
             }
 
-            DifferentialFunction[] funcs = sd.functions();
+            DifferentialFunction[] funcs = sd.ops();
             String name = (opName == null ? funcs[0].opName() : opName);
 
             String msg = "test: " + i + " - " + name;
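
The changes in this file mix two things: the same Old*-op replacement seen earlier (OldMax/OldMin to Max/Min, again taking exec(...)[0]) and a plain SameDiff API rename, sd.functions() to sd.ops(), which still returns the graph's operations as a DifferentialFunction[]. The rename in isolation:

    DifferentialFunction[] ops = sd.ops();            // formerly sd.functions()
    String name = opName == null ? ops[0].opName() : opName;
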
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java
index 1b94afab4..380f9e881 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java
@@ -188,11 +188,11 @@
                 assertEquals(varsOrig.get(j).getVarName(), varsRestored.get(j).getVarName());
             }
 
-            DifferentialFunction[] fOrig = sd.functions();
-            DifferentialFunction[] fRestored = restored.functions();
+            DifferentialFunction[] fOrig = sd.ops();
+            DifferentialFunction[] fRestored = restored.ops();
             assertEquals(fOrig.length, fRestored.length);
-            for (int j = 0; j < sd.functions().length; j++) {
+            for (int j = 0; j < sd.ops().length; j++) {
                 assertEquals(fOrig[j].getClass(), fRestored[j].getClass());
             }
 
@@ -224,7 +224,7 @@
             sd.save(f2, withUpdaterState);
             SameDiff r2 = SameDiff.load(f2, withUpdaterState);
             assertEquals(varsOrig.size(), r2.variables().size());
-            assertEquals(fOrig.length, r2.functions().length);
+            assertEquals(fOrig.length, r2.ops().length);
             assertEquals(sd.getLossVariables(), r2.getLossVariables());
 
             //Save via stream:
@@ -237,7 +237,7 @@
             try(InputStream is = new BufferedInputStream(new FileInputStream(f3))) {
                 SameDiff r3 = SameDiff.load(is, withUpdaterState);
                 assertEquals(varsOrig.size(), r3.variables().size());
-                assertEquals(fOrig.length, r3.functions().length);
+                assertEquals(fOrig.length, r3.ops().length);
                 assertEquals(sd.getLossVariables(), r3.getLossVariables());
             }
         }
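
Beyond the ops() rename, the round-trip being asserted here doubles as a reference for the SameDiff persistence API used in this patch: save to a File or stream with an updater-state flag, then load with the same flag. Sketch (path illustrative):

    File f = new File("samediff-model.fb");
    sd.save(f, true);                              // true: include updater state
    SameDiff restored = SameDiff.load(f, true);
    assertEquals(sd.variables().size(), restored.variables().size());
    assertEquals(sd.ops().length, restored.ops().length);
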
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java
index 482ca5d4f..083fdbfe8 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java
@@ -19,7 +19,6 @@
 package org.nd4j.autodiff.samediff;
 
 import lombok.extern.slf4j.Slf4j;
 import org.junit.Test;
 import org.nd4j.autodiff.samediff.transform.*;
-import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.BaseNd4jTest;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -58,17 +57,17 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
 
         SDVariable sub = add.sub(add2);
 
-        assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputFunction(add.getVarName())));
-        assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputFunction(add2.getVarName())));
-        assertFalse(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputFunction(sub.getVarName())));
+        assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add.getVarName())));
+        assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add2.getVarName())));
+        assertFalse(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(sub.getVarName())));
 
-        assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputFunction(add.getVarName())));
-        assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputFunction(add2.getVarName())));
-        assertFalse(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputFunction(sub.getVarName())));
+        assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add.getVarName())));
+        assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add2.getVarName())));
+        assertFalse(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(sub.getVarName())));
 
-        assertTrue(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputFunction(add.getVarName())));
-        assertTrue(OpPredicate.opNameMatches("ad.*").matches(sd, sd.getVariableOutputFunction(add2.getVarName())));
-        assertFalse(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputFunction(sub.getVarName())));
+        assertTrue(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(add.getVarName())));
+        assertTrue(OpPredicate.opNameMatches("ad.*").matches(sd, sd.getVariableOutputOp(add2.getVarName())));
+        assertFalse(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(sub.getVarName())));
 
         SubGraphPredicate p = SubGraphPredicate.withRoot(OpPredicate.classEquals(AddOp.class));
 
@@ -77,11 +76,11 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
         assertEquals(2, l.size());
 
         SubGraph sg1 = l.get(0);
-        assertTrue(sg1.getRootNode() == sd.getVariableOutputFunction(add.getVarName()));
+        assertTrue(sg1.getRootNode() == sd.getVariableOutputOp(add.getVarName()));
         assertEquals(0, sg1.getChildNodes().size());
 
         SubGraph sg2 = l.get(1);
-        assertTrue(sg2.getRootNode() == sd.getVariableOutputFunction(add2.getVarName()));
+        assertTrue(sg2.getRootNode() == sd.getVariableOutputOp(add2.getVarName()));
         assertEquals(0, sg2.getChildNodes().size());
     }
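
The same rename family again: sd.getVariableOutputFunction(name) becomes sd.getVariableOutputOp(name), returning the op (if any) that produces the named variable. A sketch of the predicate usage from this test, assuming the rename keeps the return type accepted by OpPredicate.matches:

    DifferentialFunction producer = sd.getVariableOutputOp(add.getVarName());
    assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, producer));
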
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
index becd05aa7..ddc1e24b7 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
@@ -59,13 +59,13 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNorma
 import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
 import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
 import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
-import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax;
-import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Min;
 import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
@@ -1759,11 +1759,11 @@ public class SameDiffTests extends BaseNd4jTest {
                     break;
                 case 7:
                     t = sd.max(in1, in2);
-                    expOut = Nd4j.getExecutioner().exec(new OldMax(ia, ib, ia.dup()));
+                    expOut = Nd4j.getExecutioner().exec(new Max(ia, ib, ia.dup()))[0];
                     break;
                 case 8:
                     t = sd.min(in1, in2);
-                    expOut = Nd4j.getExecutioner().exec(new OldMin(ia, ib, ia.dup()));
+                    expOut = Nd4j.getExecutioner().exec(new Min(ia, ib, ia.dup()))[0];
                     break;
                 case 9:
                     ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));
 
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java
index 46300d1bc..ab86f829e 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java
@@ -16,7 +16,6 @@
 
 package org.nd4j.imports.TFGraphs;
 
-import com.google.common.primitives.Doubles;
 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
 import org.apache.commons.io.FilenameUtils;
@@ -38,7 +37,6 @@ import org.nd4j.autodiff.validation.OpValidation;
 import org.nd4j.base.Preconditions;
 import org.nd4j.imports.TFGraphs.listener.OpExecOrderListener;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
-import org.nd4j.linalg.BaseNd4jTest;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.iter.NdIndexIterator;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -46,7 +44,6 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
 import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
 import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
 import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.factory.Nd4jBackend;
 import org.nd4j.linalg.function.BiFunction;
 import org.nd4j.linalg.indexing.BooleanIndexing;
 import org.nd4j.linalg.indexing.conditions.Conditions;
@@ -301,7 +298,7 @@
         Map<String, SameDiffOp> fns = graph.getOps();
         List<String> execOrder = listener.getOpNamesList();
         for(String opName : execOrder){
-            String[] outputs = graph.getOutputsForFunction(fns.get(opName).getOp());
+            String[] outputs = graph.getOutputsForOp(fns.get(opName).getOp());
             Collections.addAll(varNames, outputs);
         }
 
@@ -334,8 +331,8 @@
                     if(countExceeds > 0){
                         maxRE = relError.maxNumber().doubleValue();
                         //Find the op that this variable is produced by
-                        op = graph.getVariableOutputFunction(varName);
-                        opInputs = graph.getInputsForFunction(op);
+                        op = graph.getVariableOutputOp(varName);
+                        opInputs = graph.getInputsForOp(op);
                     }
 
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java
index d79fb523c..c6878d1be 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java
@@ -732,9 +732,9 @@ public class TensorFlowImportTest extends BaseNd4jTest {
         }
 
         val functions = new HashMap<>();
-        for (val func: tg.functions()) {
+        for (val func: tg.ops()) {
             val ownName = func.getOwnName();
-            val outName = func.outputVariables()[0].getVarName();
+            String outName = func.outputVariables()[0].getVarName();
             assertTrue("Missing ownName: [" + ownName +"]",variables.containsKey(ownName));
             assertEquals(ownName, outName);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java
index 71406b70a..b948220fd 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java
@@ -72,12 +72,12 @@ import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
 import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
 import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpaceND;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.Axpy;
-import org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse;
 import org.nd4j.linalg.api.ops.impl.transforms.same.Sign;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
@@ -5226,7 +5226,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
         INDArray array = Nd4j.create(new double[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
         INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
 
-        INDArray rev = Nd4j.getExecutioner().exec(new OldReverse(array, Nd4j.createUninitialized(array.length())));
+        INDArray rev = Nd4j.getExecutioner().exec(new Reverse(array, Nd4j.createUninitialized(array.length())))[0];
         assertEquals(exp, rev);
     }
 
@@ -5236,7 +5236,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
         INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
         INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
 
-        INDArray rev = Nd4j.getExecutioner().exec(new OldReverse(array, Nd4j.createUninitialized(array.length())));
+        INDArray rev = Nd4j.getExecutioner().exec(new Reverse(array, Nd4j.createUninitialized(array.length())))[0];
         assertEquals(exp, rev);
     }
 
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java
index 19085629b..e92f03c39 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java
@@ -35,7 +35,7 @@ import org.nd4j.linalg.api.ops.impl.reduce.bool.IsInf;
 import org.nd4j.linalg.api.ops.impl.reduce.bool.IsNaN;
 import org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero;
 import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
-import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldEqualTo;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
 import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
 import org.nd4j.linalg.factory.Nd4j;
@@ -276,7 +276,7 @@ public class MixedDataTypesTests extends BaseNd4jTest {
         val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.INT);
         val exp = new long[]{1, 0, 0, 1};
 
-        val result = Nd4j.getExecutioner().exec(new OldEqualTo(arrayX, arrayY));
+        val result = Nd4j.getExecutioner().exec(new EqualTo(arrayX, arrayY, arrayX.ulike().castTo(DataType.BOOL)))[0];
         assertEquals(DataType.BOOL, result.dataType());
         val arr = result.data().asLong();
 
@@ -369,13 +369,13 @@ public class MixedDataTypesTests extends BaseNd4jTest {
         val result = Nd4j.getExecutioner().exec(op);
     }
 
-    @Test(expected = IllegalArgumentException.class)
+    @Test(expected = RuntimeException.class)
     public void testTypesValidation_2() {
         val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
         val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.LONG);
         val exp = new long[]{1, 0, 0, 1};
 
-        val result = Nd4j.getExecutioner().exec(new OldEqualTo(arrayX, arrayY));
+        val result = Nd4j.getExecutioner().exec(new EqualTo(arrayX, arrayY, arrayX.ulike().castTo(DataType.BOOL)))[0];
         val arr = result.data().asLong();
 
         assertArrayEquals(exp, arr);
 
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java
index c6414002e..49391b74c 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java
@@ -45,7 +45,7 @@ import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan;
 import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
-import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
+import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.Log;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.SetRange;
@@ -205,7 +205,7 @@ public class OpExecutionerTests extends BaseNd4jTest {
         INDArray x = Nd4j.ones(5);
         INDArray xDup = x.dup();
         INDArray solution = Nd4j.valueArrayOf(5, 1.0);
-        opExecutioner.exec(new OldMulOp(x, xDup, x));
+        opExecutioner.exec(new MulOp(x, xDup, x));
         assertEquals(solution, x);
     }
 
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java
index f46d5e694..866c592b7 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java
@@ -55,7 +55,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.HistogramFixedWidth;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
-import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
+import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.Log;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.SetRange;
@@ -236,7 +236,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest {
         INDArray x = Nd4j.ones(5);
         INDArray xDup = x.dup();
         INDArray solution = Nd4j.valueArrayOf(5, 1.0);
-        opExecutioner.exec(new OldMulOp(x, xDup, x));
+        opExecutioner.exec(new MulOp(x, xDup, x));
         assertEquals(solution, x);
     }
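
The executioner tests above exercise the same Old*-to-custom-op migration for pairwise arithmetic: MulOp keeps the (x, y, z) constructor shape of OldMulOp, and passing x as the output z still performs the multiplication in place. From the test, as a standalone sketch:

    INDArray x = Nd4j.ones(5);
    INDArray y = x.dup();
    // z aliases x, so x is overwritten with x * y
    Nd4j.getExecutioner().exec(new MulOp(x, y, x));
    assertEquals(Nd4j.valueArrayOf(5, 1.0), x);
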
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTests.java
index d2c643461..6c98c5afb 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTests.java
@@ -72,8 +72,9 @@ public class PaddingTests extends BaseNd4jTest {
 
     @Test
     public void testPad() {
+        INDArray start = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
-        INDArray ret = Nd4j.pad(start, new int[] {5, 5}, Nd4j.PadMode.CONSTANT);
+        INDArray ret = Nd4j.pad(start, 5, 5);
         double[][] data = new double[][] {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.},
                 {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.},
                 {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.},
                 {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.},
                 {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.},
 
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java
index 1b8b5e455..7e9f1e91c 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java
@@ -64,8 +64,7 @@ public class PaddingTestsC extends BaseNd4jTest {
         INDArray ret = Nd4j.create(new double[] {1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4}, new int[] {1, 1, 8, 8});
 
-        INDArray padded = Nd4j.pad(ret, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}},
-                        Nd4j.PadMode.CONSTANT);
+        INDArray padded = Nd4j.pad(ret, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}});
         INDArray assertion = Nd4j.create(new double[] {1, 1, 1, 1, 1, 1, 1, 1, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0,
                 3, 3, 3, 3, 3, 3, 3, 3, 0, 4, 4, 4, 4, 4, 4, 4, 4, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0,
@@ -104,8 +103,7 @@ public class PaddingTestsC extends BaseNd4jTest {
         // FIXME: int cast
         int outWidth = Convolution.outSize((int) h, kh, sy, ph, 1, true);
         int outHeight = Convolution.outSize((int) w, kw, sx, pw, 1, true);
 
-        INDArray padded = Nd4j.pad(linspaced, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}},
-                        Nd4j.PadMode.CONSTANT);
+        INDArray padded = Nd4j.pad(linspaced, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}});
         System.out.println(padded);
     }
 
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java
index 834ad5689..9593b5a3b 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java
@@ -127,8 +127,7 @@ public class IndexingTestsC extends BaseNd4jTest {
                         4, 4, 4, 4, 4, 4, 4, 4}, new long[] {1, 1, 8, 8});
 
-        INDArray padded = Nd4j.pad(img, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}},
-                        Nd4j.PadMode.CONSTANT);
+        INDArray padded = Nd4j.pad(img, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}});
         INDArray get = padded.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(i, sy, iLim),
                 NDArrayIndex.interval(j, sx, jLim));
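
The padding tests record the last API change in this patch: Nd4j.pad no longer takes an Nd4j.PadMode argument, and the overloads used here pad with the constant value zero (which is what the all-zero border asserted in testPad shows). The two call forms, sketched:

    INDArray a = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
    INDArray p1 = Nd4j.pad(a, 5, 5);                            // 5 zeros on every side of each dimension (13x13 result in testPad)
    INDArray p2 = Nd4j.pad(a, new int[][] {{0, 0}, {1, 2}});    // explicit {before, after} padding per dimension
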