[WIP] Various fixes, mostly SameDiff/Nd4j (#110)

* Nd4j pad update

Signed-off-by: Ryan Nett <rnett@skymind.io>

* switched from guava Immutables to Collections.unmodifiableList/Map

Signed-off-by: Ryan Nett <rnett@skymind.io>

* javadoc

Signed-off-by: Ryan Nett <rnett@skymind.io>

* use new pad

Signed-off-by: Ryan Nett <rnett@skymind.io>

* conv tests use OpValidation

Signed-off-by: Ryan Nett <rnett@skymind.io>

* deconv3d overrides

Signed-off-by: Ryan Nett <rnett@skymind.io>

* test fix for the new pad method

Signed-off-by: Ryan Nett <rnett@skymind.io>

* more test fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* more test fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* rename SameDiff function methods to op (except for the actual SameDiff function ones)

Signed-off-by: Ryan Nett <rnett@skymind.io>

* more pad overloads, test fix

Signed-off-by: Ryan Nett <rnett@skymind.io>

* test updates

Signed-off-by: Ryan Nett <rnett@skymind.io>

* conv1d test

Signed-off-by: Ryan Nett <rnett@skymind.io>

* remove Conv1D tf import (there isn't a TF conv1d op)

Signed-off-by: Ryan Nett <rnett@skymind.io>

* remove numThreads from Nd4j

Signed-off-by: Ryan Nett <rnett@skymind.io>

* replace Old ops with their newer versions, deprecate ones that haven't already been deprecated

Signed-off-by: Ryan Nett <rnett@skymind.io>

* remove use of setNumThreads

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fix for Reverse and ATan2

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fix test for wrong equals type

Signed-off-by: Ryan Nett <rnett@skymind.io>

* well it works now

Signed-off-by: Ryan Nett <rnett@skymind.io>

* better javadocs

Signed-off-by: Ryan Nett <rnett@skymind.io>

* NonNulls

Signed-off-by: Ryan Nett <rnett@skymind.io>

* better array literal

Signed-off-by: Ryan Nett <rnett@skymind.io>

* re-add tf import stuff (will remove later)

Signed-off-by: Ryan Nett <rnett@skymind.io>

* conv1d config load fix

Signed-off-by: Ryan Nett <rnett@skymind.io>

* partial config usage changes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* remove Old op classes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* config property fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* removed one too many ops

Signed-off-by: Ryan Nett <rnett@skymind.io>
master
Ryan Nett 2019-08-21 16:40:32 -07:00 committed by GitHub
parent eea3062ccf
commit 2b0d7b3b52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
98 changed files with 696 additions and 2244 deletions

View File

@ -24,11 +24,10 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.api.ops.random.impl.AlphaDropOut; import org.nd4j.linalg.api.ops.random.impl.AlphaDropOut;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.schedule.ISchedule; import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties; import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.annotation.JsonProperty;
@ -139,7 +138,7 @@ public class AlphaDropout implements IDropout {
//a * (x * d + alphaPrime * (1-d)) + b //a * (x * d + alphaPrime * (1-d)) + b
INDArray inverseMask = mask.rsub(1.0); INDArray inverseMask = mask.rsub(1.0);
INDArray aPOneMinusD = inverseMask.muli(alphaPrime); INDArray aPOneMinusD = inverseMask.muli(alphaPrime);
Nd4j.getExecutioner().exec(new OldMulOp(inputActivations, mask, output)); //out = x * d Nd4j.getExecutioner().exec(new MulOp(inputActivations, mask, output)); //out = x * d
output.addi(aPOneMinusD).muli(a).addi(b); output.addi(aPOneMinusD).muli(a).addi(b);
//Nd4j.getExecutioner().exec(new AlphaDropOut(inputActivations, output, p, a, alphaPrime, b)); //Nd4j.getExecutioner().exec(new AlphaDropOut(inputActivations, output, p, a, alphaPrime, b));
@ -152,7 +151,7 @@ public class AlphaDropout implements IDropout {
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
// dOut/dIn = 0 if dropped (d=0), or a otherwise (d=1) // dOut/dIn = 0 if dropped (d=0), or a otherwise (d=1)
mask.muli(a); mask.muli(a);
Nd4j.getExecutioner().exec(new OldMulOp(gradAtOutput, mask, gradAtInput)); Nd4j.getExecutioner().exec(new MulOp(gradAtOutput, mask, gradAtInput));
mask = null; mask = null;
return gradAtInput; return gradAtInput;
} }

View File

@ -24,7 +24,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted; import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.schedule.ISchedule; import org.nd4j.linalg.schedule.ISchedule;
@ -153,7 +153,7 @@ public class Dropout implements IDropout {
mask = workspaceMgr.createUninitialized(ArrayType.INPUT, output.dataType(), output.shape(), output.ordering()).assign(1.0); mask = workspaceMgr.createUninitialized(ArrayType.INPUT, output.dataType(), output.shape(), output.ordering()).assign(1.0);
Nd4j.getExecutioner().exec(new DropOutInverted(mask, mask, currP)); Nd4j.getExecutioner().exec(new DropOutInverted(mask, mask, currP));
Nd4j.getExecutioner().exec(new OldMulOp(inputCast, mask, output)); Nd4j.getExecutioner().exec(new MulOp(inputCast, mask, output));
return output; return output;
} }
@ -171,7 +171,7 @@ public class Dropout implements IDropout {
if(m.dataType() != gradAtInput.dataType()){ if(m.dataType() != gradAtInput.dataType()){
m = m.castTo(gradAtInput.dataType()); m = m.castTo(gradAtInput.dataType());
} }
Nd4j.getExecutioner().exec(new OldMulOp(gradAtOutput, m, gradAtInput)); Nd4j.getExecutioner().exec(new MulOp(gradAtOutput, m, gradAtInput));
mask = null; mask = null;
return gradAtInput; return gradAtInput;
} }

View File

@ -22,7 +22,7 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution; import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.schedule.ISchedule; import org.nd4j.linalg.schedule.ISchedule;
@ -88,7 +88,7 @@ public class GaussianDropout implements IDropout {
noise = workspaceMgr.createUninitialized(ArrayType.INPUT, output.dataType(), inputActivations.shape(), inputActivations.ordering()); noise = workspaceMgr.createUninitialized(ArrayType.INPUT, output.dataType(), inputActivations.shape(), inputActivations.ordering());
Nd4j.getExecutioner().exec(new GaussianDistribution(noise, 1.0, stdev)); Nd4j.getExecutioner().exec(new GaussianDistribution(noise, 1.0, stdev));
return Nd4j.getExecutioner().exec(new OldMulOp(inputActivations, noise, output)); return Nd4j.getExecutioner().exec(new MulOp(inputActivations, noise, output))[0];
} }
@Override @Override
@ -96,7 +96,7 @@ public class GaussianDropout implements IDropout {
Preconditions.checkState(noise != null, "Cannot perform backprop: GaussianDropout noise array is absent (already cleared?)"); Preconditions.checkState(noise != null, "Cannot perform backprop: GaussianDropout noise array is absent (already cleared?)");
//out = in*y, where y ~ N(1, stdev) //out = in*y, where y ~ N(1, stdev)
//dL/dIn = dL/dOut * dOut/dIn = y * dL/dOut //dL/dIn = dL/dOut * dOut/dIn = y * dL/dOut
Nd4j.getExecutioner().exec(new OldMulOp(gradAtOutput, noise, gradAtInput)); Nd4j.getExecutioner().exec(new MulOp(gradAtOutput, noise, gradAtInput));
noise = null; noise = null;
return gradAtInput; return gradAtInput;
} }

View File

@ -19,7 +19,7 @@ package org.deeplearning4j.nn.conf.dropout;
import lombok.Data; import lombok.Data;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution; import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.schedule.ISchedule; import org.nd4j.linalg.schedule.ISchedule;
@ -69,7 +69,7 @@ public class GaussianNoise implements IDropout {
INDArray noise = Nd4j.createUninitialized(output.dataType(), inputActivations.shape(), inputActivations.ordering()); INDArray noise = Nd4j.createUninitialized(output.dataType(), inputActivations.shape(), inputActivations.ordering());
Nd4j.getExecutioner().exec(new GaussianDistribution(noise, 0, currS)); Nd4j.getExecutioner().exec(new GaussianDistribution(noise, 0, currS));
Nd4j.getExecutioner().exec(new OldAddOp(inputActivations, noise, output)); Nd4j.getExecutioner().exec(new AddOp(inputActivations, noise, output));
return output; return output;
} }

View File

@ -24,7 +24,7 @@ import org.nd4j.linalg.activations.impl.ActivationHardSigmoid;
import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldLessThan; import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.linalg.indexing.conditions.Conditions;
@ -144,7 +144,7 @@ public class BernoulliReconstructionDistribution implements ReconstructionDistri
INDArray out = Nd4j.createUninitialized(DataType.BOOL, p.shape()); INDArray out = Nd4j.createUninitialized(DataType.BOOL, p.shape());
Nd4j.getExecutioner().execAndReturn(new OldLessThan(rand, p, out)); Nd4j.getExecutioner().execAndReturn(new LessThan(rand, p, out));
return out.castTo(DataType.FLOAT); return out.castTo(DataType.FLOAT);
} }

View File

@ -22,8 +22,8 @@ import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.distribution.Distributions; import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.ArrayType;
@ -86,9 +86,9 @@ public class WeightNoise implements IWeightNoise {
INDArray out = workspaceMgr.createUninitialized(ArrayType.INPUT, param.dataType(), param.shape(), param.ordering()); INDArray out = workspaceMgr.createUninitialized(ArrayType.INPUT, param.dataType(), param.shape(), param.ordering());
if (additive) { if (additive) {
Nd4j.getExecutioner().exec(new OldAddOp(param, noise,out)); Nd4j.getExecutioner().exec(new AddOp(param, noise,out));
} else { } else {
Nd4j.getExecutioner().exec(new OldMulOp(param, noise, out)); Nd4j.getExecutioner().exec(new MulOp(param, noise, out));
} }
return out; return out;
} }

View File

@ -34,8 +34,8 @@ import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldDivOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldSubOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JOpProfilerException; import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -205,7 +205,7 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
INDArray batchMean = helper.getMeanCache(dataType); INDArray batchMean = helper.getMeanCache(dataType);
INDArray batchVar = helper.getVarCache(dataType); INDArray batchVar = helper.getVarCache(dataType);
Nd4j.getExecutioner().exec(new OldSubOp(globalMean, batchMean, dGlobalMeanView)); //deltaGlobalMean = globalMean[t] - batchMean Nd4j.getExecutioner().exec(new SubOp(globalMean, batchMean, dGlobalMeanView)); //deltaGlobalMean = globalMean[t] - batchMean
dGlobalMeanView.muli(1-layerConf().getDecay()); dGlobalMeanView.muli(1-layerConf().getDecay());
if(layerConf().isUseLogStd()){ if(layerConf().isUseLogStd()){
@ -219,12 +219,12 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
double decay = layerConf().getDecay(); double decay = layerConf().getDecay();
INDArray varip1 = vari.mul(decay).addi(batchVar.mul(1-decay)); INDArray varip1 = vari.mul(decay).addi(batchVar.mul(1-decay));
Nd4j.getExecutioner().exec(new OldDivOp(vari, varip1, dGlobalLog10StdView)); Nd4j.getExecutioner().exec(new DivOp(vari, varip1, dGlobalLog10StdView));
Transforms.log(dGlobalLog10StdView, false); Transforms.log(dGlobalLog10StdView, false);
dGlobalLog10StdView.muli(ONE_ON_2LOGE_10); dGlobalLog10StdView.muli(ONE_ON_2LOGE_10);
} else { } else {
//Use variance estimate parameterization. This was only option up to and including 1.0.0-beta3 //Use variance estimate parameterization. This was only option up to and including 1.0.0-beta3
Nd4j.getExecutioner().exec(new OldSubOp(globalVar, batchVar, dGlobalVarView)); //deltaGlobalVar = globalVar[t] - batchVar Nd4j.getExecutioner().exec(new SubOp(globalVar, batchVar, dGlobalVarView)); //deltaGlobalVar = globalVar[t] - batchVar
dGlobalVarView.muli(1 - layerConf().getDecay()); dGlobalVarView.muli(1 - layerConf().getDecay());
} }
@ -343,7 +343,7 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
And use the same idea for global variance estimate And use the same idea for global variance estimate
*/ */
Nd4j.getExecutioner().exec(new OldSubOp(globalMean, batchMean, dGlobalMeanView)); //deltaGlobalMean = globalMean[t] - batchMean Nd4j.getExecutioner().exec(new SubOp(globalMean, batchMean, dGlobalMeanView)); //deltaGlobalMean = globalMean[t] - batchMean
dGlobalMeanView.muli(1-layerConf().getDecay()); dGlobalMeanView.muli(1-layerConf().getDecay());
if(layerConf().isUseLogStd()){ if(layerConf().isUseLogStd()){
@ -357,12 +357,12 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
double decay = layerConf().getDecay(); double decay = layerConf().getDecay();
INDArray varip1 = vari.mul(decay).addi(batchVar.mul(1-decay)); INDArray varip1 = vari.mul(decay).addi(batchVar.mul(1-decay));
Nd4j.getExecutioner().exec(new OldDivOp(vari, varip1, dGlobalLog10StdView)); Nd4j.getExecutioner().exec(new DivOp(vari, varip1, dGlobalLog10StdView));
Transforms.log(dGlobalLog10StdView, false); Transforms.log(dGlobalLog10StdView, false);
dGlobalLog10StdView.muli(ONE_ON_2LOGE_10); dGlobalLog10StdView.muli(ONE_ON_2LOGE_10);
} else { } else {
//Use variance estimate parameterization. This was only option up to and including 1.0.0-beta3 //Use variance estimate parameterization. This was only option up to and including 1.0.0-beta3
Nd4j.getExecutioner().exec(new OldSubOp(globalVar, batchVar, dGlobalVarView)); //deltaGlobalVar = globalVar[t] - batchVar Nd4j.getExecutioner().exec(new SubOp(globalVar, batchVar, dGlobalVarView)); //deltaGlobalVar = globalVar[t] - batchVar
dGlobalVarView.muli(1 - layerConf().getDecay()); dGlobalVarView.muli(1 - layerConf().getDecay());
} }

View File

@ -23,10 +23,9 @@ import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer; import org.deeplearning4j.nn.layers.AbstractLayer;
import org.deeplearning4j.nn.layers.LayerHelper; import org.deeplearning4j.nn.layers.LayerHelper;
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNLocalResponseNormalizationHelper;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.exception.ND4JOpProfilerException; import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.INDArrayIndex;
@ -187,7 +186,7 @@ public class LocalResponseNormalization
// gx = gy * unitScale**-beta - 2 * alpha * beta * sumPart/unitScale * a^i_{x,y} - rearranged for more in-place ops // gx = gy * unitScale**-beta - 2 * alpha * beta * sumPart/unitScale * a^i_{x,y} - rearranged for more in-place ops
INDArray nextEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), epsilon.shape(), epsilon.ordering()); INDArray nextEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), epsilon.shape(), epsilon.ordering());
Nd4j.getExecutioner().exec(new OldMulOp(epsilon, scale, nextEpsilon)); Nd4j.getExecutioner().exec(new MulOp(epsilon, scale, nextEpsilon));
nextEpsilon.subi(sumPart.muli(input).divi(unitScale).muli(2 * alpha * beta)); nextEpsilon.subi(sumPart.muli(input).divi(unitScale).muli(2 * alpha * beta));
return new Pair<>(retGradient, nextEpsilon); return new Pair<>(retGradient, nextEpsilon);
} }
@ -257,7 +256,7 @@ public class LocalResponseNormalization
unitScale = sumPart.mul(alpha).addi(k); unitScale = sumPart.mul(alpha).addi(k);
// y = x * unitScale**-beta // y = x * unitScale**-beta
scale = Transforms.pow(unitScale, -beta, true); scale = Transforms.pow(unitScale, -beta, true);
Nd4j.getExecutioner().exec(new OldMulOp(input, scale, activations)); Nd4j.getExecutioner().exec(new MulOp(input, scale, activations));
} else { } else {
// unitScale = (k + alpha * sum_{j=max(0, i - n/2)}^{max(N-1, i + n/2)} (a^j_{x,y})^2 ) // unitScale = (k + alpha * sum_{j=max(0, i - n/2)}^{max(N-1, i + n/2)} (a^j_{x,y})^2 )
sumPart.muli(alpha, activations).addi(k); sumPart.muli(alpha, activations).addi(k);

View File

@ -35,7 +35,7 @@ import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.api.ops.impl.transforms.same.TimesOneMinus; import org.nd4j.linalg.api.ops.impl.transforms.same.TimesOneMinus;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -554,7 +554,7 @@ public class LSTMHelpers {
//Normally would use zo.dup() in above line, but won't be using zo again (for this time step). Ditto for zf, zg, zi //Normally would use zo.dup() in above line, but won't be using zo again (for this time step). Ditto for zf, zg, zi
INDArray deltao = deltaoNext; INDArray deltao = deltaoNext;
Nd4j.getExecutioner().exec(new OldMulOp(nablaOut, sigmahOfS, deltao)); Nd4j.getExecutioner().exec(new MulOp(nablaOut, sigmahOfS, deltao));
if (sigmoidGates) { if (sigmoidGates) {
INDArray sigmaoPrimeOfZo = Nd4j.getExecutioner().exec(new TimesOneMinus(ao.dup('f'))); //Equivalent to sigmoid deriv on zo INDArray sigmaoPrimeOfZo = Nd4j.getExecutioner().exec(new TimesOneMinus(ao.dup('f'))); //Equivalent to sigmoid deriv on zo
deltao.muli(sigmaoPrimeOfZo); deltao.muli(sigmaoPrimeOfZo);
@ -607,7 +607,7 @@ public class LSTMHelpers {
deltag.muli(ai); deltag.muli(ai);
deltag.muli(nablaCellState); deltag.muli(nablaCellState);
} else { } else {
INDArray temp2 = Nd4j.getExecutioner().exec(new OldMulOp(ai, nablaCellState, Nd4j.createUninitialized(inputWeights.dataType(), ai.shape(), 'f'))); INDArray temp2 = Nd4j.getExecutioner().exec(new MulOp(ai, nablaCellState, Nd4j.createUninitialized(inputWeights.dataType(), ai.shape(), 'f')))[0];
deltag.assign(gateActivationFn.backprop(fwdPass.gz[time], temp2).getFirst()); deltag.assign(gateActivationFn.backprop(fwdPass.gz[time], temp2).getFirst());
//TODO activation functions with params; optimize (no assign) //TODO activation functions with params; optimize (no assign)
} }
@ -616,7 +616,7 @@ public class LSTMHelpers {
//Network input delta: //Network input delta:
INDArray zi = fwdPass.iz[time]; INDArray zi = fwdPass.iz[time];
INDArray deltai = deltaiNext; INDArray deltai = deltaiNext;
temp = Nd4j.getExecutioner().exec(new OldMulOp(ag, nablaCellState, Nd4j.createUninitialized(inputWeights.dataType(), deltai.shape(), 'f'))); temp = Nd4j.getExecutioner().exec(new MulOp(ag, nablaCellState, Nd4j.createUninitialized(inputWeights.dataType(), deltai.shape(), 'f')))[0];
deltai.assign(afn.backprop(zi, temp).getFirst()); deltai.assign(afn.backprop(zi, temp).getFirst());
//TODO activation functions with params; also: optimize this (no assign) //TODO activation functions with params; also: optimize this (no assign)
//Shape: [m,n^L] //Shape: [m,n^L]

View File

@ -26,7 +26,6 @@ import onnx.OnnxProto3;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.graph.DataType;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder; import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter; import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.descriptors.properties.PropertyMapping;
@ -520,7 +519,7 @@ public abstract class DifferentialFunction {
* @return the arguments for a given function * @return the arguments for a given function
*/ */
public SDVariable[] args() { public SDVariable[] args() {
return sameDiff.getInputVariablesForFunction(this); return sameDiff.getInputVariablesForOp(this);
} }
/** /**
@ -661,7 +660,7 @@ public abstract class DifferentialFunction {
} }
if(sameDiff != null && !(this instanceof SDVariable)) if(sameDiff != null && !(this instanceof SDVariable))
sameDiff.putFunctionForId(ownName,this); sameDiff.putOpForId(ownName,this);
} }
} }

View File

@ -23,6 +23,7 @@ import com.google.common.collect.Lists;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -39,12 +40,12 @@ import org.nd4j.evaluation.IMetric;
@Getter @Getter
public class EvaluationRecord { public class EvaluationRecord {
private ImmutableMap<String, List<IEvaluation>> evaluations; private Map<String, List<IEvaluation>> evaluations;
private Map<Class<? extends IEvaluation>, IEvaluation> classEvaluations = new HashMap<>(); private Map<Class<? extends IEvaluation>, IEvaluation> classEvaluations = new HashMap<>();
private boolean isEmpty = true; private boolean isEmpty = true;
public EvaluationRecord(Map<String, List<IEvaluation>> evaluations) { public EvaluationRecord(Map<String, List<IEvaluation>> evaluations) {
this.evaluations = ImmutableMap.copyOf(evaluations); this.evaluations = Collections.unmodifiableMap(evaluations);
for (List<IEvaluation> le : evaluations.values()) { for (List<IEvaluation> le : evaluations.values()) {
for (IEvaluation e : le) { for (IEvaluation e : le) {
@ -68,7 +69,7 @@ public class EvaluationRecord {
/** /**
* Get all evaluations * Get all evaluations
*/ */
public ImmutableMap<String, List<IEvaluation>> evaluations() { public Map<String, List<IEvaluation>> evaluations() {
return evaluations; return evaluations;
} }

View File

@ -16,8 +16,8 @@
package org.nd4j.autodiff.listeners.records; package org.nd4j.autodiff.listeners.records;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import lombok.Getter; import lombok.Getter;
@ -49,11 +49,11 @@ public class History {
public History(List<EvaluationRecord> training, List<EvaluationRecord> validation, LossCurve loss, public History(List<EvaluationRecord> training, List<EvaluationRecord> validation, LossCurve loss,
long trainingTimeMillis, List<Long> validationTimesMillis){ long trainingTimeMillis, List<Long> validationTimesMillis){
trainingHistory = ImmutableList.copyOf(training); trainingHistory = Collections.unmodifiableList(training);
validationHistory = ImmutableList.copyOf(validation); validationHistory = Collections.unmodifiableList(validation);
this.lossCurve = loss; this.lossCurve = loss;
this.trainingTimeMillis = trainingTimeMillis; this.trainingTimeMillis = trainingTimeMillis;
this.validationTimesMillis = ImmutableList.copyOf(validationTimesMillis); this.validationTimesMillis = Collections.unmodifiableList(validationTimesMillis);
} }
/** /**

View File

@ -16,8 +16,8 @@
package org.nd4j.autodiff.listeners.records; package org.nd4j.autodiff.listeners.records;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import lombok.Getter; import lombok.Getter;
import lombok.NonNull; import lombok.NonNull;
@ -35,7 +35,7 @@ public class LossCurve {
private INDArray lossValues; private INDArray lossValues;
public LossCurve(List<Loss> losses){ public LossCurve(List<Loss> losses){
lossNames = ImmutableList.copyOf(losses.get(0).getLossNames()); lossNames = Collections.unmodifiableList(losses.get(0).getLossNames());
int numLossValues = losses.get(0).lossValues().length; int numLossValues = losses.get(0).lossValues().length;
lossValues = Nd4j.create(DataType.FLOAT, losses.size(), losses.get(0).lossValues().length); lossValues = Nd4j.create(DataType.FLOAT, losses.size(), losses.get(0).lossValues().length);

View File

@ -466,8 +466,11 @@ public class SameDiff extends SDBaseOps {
} }
/** /**
* Set the current {@link Listener} instances. * Set the current SameDiff-wide {@link Listener} instances.
* Note that *
* Note that this will overwrite the current listener list.
* If you want to use additional listeners for a single operation,
* use the listener arguments in those methods (e.g. {@link #fit()} and {@link FitConfig#listeners(Listener...)}).
* *
* @param listeners Listeners * @param listeners Listeners
*/ */
@ -476,19 +479,37 @@ public class SameDiff extends SDBaseOps {
addListeners(listeners); addListeners(listeners);
} }
/**
* See {@link #setListeners(Listener...)}.
*/
public void setListeners(Collection<? extends Listener> listeners) { public void setListeners(Collection<? extends Listener> listeners) {
this.listeners.clear(); this.listeners.clear();
addListeners(listeners); addListeners(listeners);
} }
/**
* Add SameDiff-wide {@link Listener} instances.
*
* If you want to use additional listeners for a single operation,
* use the listener arguments in those methods (e.g. {@link #fit()} and {@link FitConfig#listeners(Listener...)}).
*
* @param listeners Listeners
*/
public void addListeners(Listener... listeners) { public void addListeners(Listener... listeners) {
addListeners(Arrays.asList(listeners)); addListeners(Arrays.asList(listeners));
} }
/**
* See {@link #addListeners(Listener...)}.
*/
public void addListeners(Collection<? extends Listener> listeners) { public void addListeners(Collection<? extends Listener> listeners) {
this.listeners.addAll(listeners); this.listeners.addAll(listeners);
} }
/**
* Gets the current SameDiff-wide listeners.
*/
public List<Listener> getListeners() { public List<Listener> getListeners() {
return listeners; return listeners;
} }
@ -585,6 +606,9 @@ public class SameDiff extends SDBaseOps {
} }
/**
* Gets all operations in a given name scope.
*/
public List<SameDiffOp> getOpsInScope(NameScope scope) { public List<SameDiffOp> getOpsInScope(NameScope scope) {
ArrayList<SameDiffOp> ops = new ArrayList<>(); ArrayList<SameDiffOp> ops = new ArrayList<>();
for (SameDiffOp v : this.ops.values()) { for (SameDiffOp v : this.ops.values()) {
@ -594,6 +618,16 @@ public class SameDiff extends SDBaseOps {
return ops; return ops;
} }
/**
* See {@link #getOpsInScope(NameScope)}.
*/
public List<SameDiffOp> getOpsInScope(String scope){
return getOpsInScope(new NameScope(this, scope));
}
/**
* Gets all variables in a given name scope.
*/
public List<SDVariable> getVariablesInScope(NameScope scope) { public List<SDVariable> getVariablesInScope(NameScope scope) {
ArrayList<SDVariable> vars = new ArrayList<>(); ArrayList<SDVariable> vars = new ArrayList<>();
for (SDVariable v : variables()) { for (SDVariable v : variables()) {
@ -603,6 +637,13 @@ public class SameDiff extends SDBaseOps {
return vars; return vars;
} }
/**
* See {@link #getVariablesInScope(NameScope)}.
*/
public List<SDVariable> getVariablesInScope(String scope){
return getVariablesInScope(new NameScope(this, scope));
}
/** /**
* @param sameDiff * @param sameDiff
* @return * @return
@ -638,8 +679,8 @@ public class SameDiff extends SDBaseOps {
function.getSameDiff()); function.getSameDiff());
clone.setSameDiff(sameDiff); clone.setSameDiff(sameDiff);
clone.setOwnName(function.getOwnName()); clone.setOwnName(function.getOwnName());
if (sameDiff.functionExists(function.getOwnName())) if (sameDiff.opExists(function.getOwnName()))
sameDiff.putFunctionForId(function.getOwnName(), function); sameDiff.putOpForId(function.getOwnName(), function);
newFunctions.put(function.getOwnName(), clone); newFunctions.put(function.getOwnName(), clone);
val argsForFunction = function.args(); val argsForFunction = function.args();
@ -672,17 +713,21 @@ public class SameDiff extends SDBaseOps {
* @param id the function id to test for * @param id the function id to test for
* @return true if the function id exists, false otherwise * @return true if the function id exists, false otherwise
*/ */
public boolean functionExists(String id) { public boolean opExists(String id) {
return ops.containsKey(id); return ops.containsKey(id);
} }
public DifferentialFunction functionOutputFor(String varName) { /**
if (variables.get(varName).getOutputOfOp() == null) * Get the differential function (if any) that this variable is the output for
*
* @param variableName Name of the variable
* @return The differential function that this variable is an output of, or null if it is not the output of a function
*/
public DifferentialFunction getVariableOutputOp(String variableName) {
Preconditions.checkState(variables.containsKey(variableName), "No variable with name \"%s\" found in graph", variableName);
if (variables.get(variableName).getOutputOfOp() == null)
return null; return null;
String outName = variables.get(varName).getOutputOfOp(); return ops.get(variables.get(variableName).getOutputOfOp()).getOp();
if (outName == null)
return null;
return ops.get(outName).getOp();
} }
/** /**
@ -691,7 +736,7 @@ public class SameDiff extends SDBaseOps {
* @param id the id of the function * @param id the id of the function
* @return the function for the given id if it exists * @return the function for the given id if it exists
*/ */
public DifferentialFunction getFunctionById(@NonNull String id) { public DifferentialFunction getOpById(@NonNull String id) {
if (!ops.containsKey(id)) { if (!ops.containsKey(id)) {
throw new ND4JIllegalStateException("No function with id " + id + " found!"); throw new ND4JIllegalStateException("No function with id " + id + " found!");
} }
@ -705,7 +750,7 @@ public class SameDiff extends SDBaseOps {
* @param id the id of the function * @param id the id of the function
* @param function the function * @param function the function
*/ */
public void putFunctionForId(String id, DifferentialFunction function) { public void putOpForId(String id, DifferentialFunction function) {
if (ops.containsKey(id) && ops.get(id).getOp() == null) { if (ops.containsKey(id) && ops.get(id).getOp() == null) {
throw new ND4JIllegalStateException("Function by id already exists!"); throw new ND4JIllegalStateException("Function by id already exists!");
} else if (function instanceof SDVariable) { } else if (function instanceof SDVariable) {
@ -726,7 +771,7 @@ public class SameDiff extends SDBaseOps {
* @param function the function to get the inputs for * @param function the function to get the inputs for
* @return the input ids for a given function * @return the input ids for a given function
*/ */
public String[] getInputsForFunction(DifferentialFunction function) { public String[] getInputsForOp(DifferentialFunction function) {
if (!ops.containsKey(function.getOwnName())) if (!ops.containsKey(function.getOwnName()))
throw new ND4JIllegalStateException("Illegal function instance id found " + function.getOwnName()); throw new ND4JIllegalStateException("Illegal function instance id found " + function.getOwnName());
List<String> inputs = ops.get(function.getOwnName()).getInputsToOp(); List<String> inputs = ops.get(function.getOwnName()).getInputsToOp();
@ -739,7 +784,7 @@ public class SameDiff extends SDBaseOps {
* @param function the function to get the outputs for * @param function the function to get the outputs for
* @return the outputs ids for a given function * @return the outputs ids for a given function
*/ */
public String[] getOutputsForFunction(DifferentialFunction function) { public String[] getOutputsForOp(DifferentialFunction function) {
if (!ops.containsKey(function.getOwnName())) if (!ops.containsKey(function.getOwnName()))
throw new ND4JIllegalStateException("Illegal function instance id found " + function.getOwnName()); throw new ND4JIllegalStateException("Illegal function instance id found " + function.getOwnName());
List<String> outputs = ops.get(function.getOwnName()).getOutputsOfOp(); List<String> outputs = ops.get(function.getOwnName()).getOutputsOfOp();
@ -753,8 +798,8 @@ public class SameDiff extends SDBaseOps {
* @param function the function reference to get the output variable(s) for * @param function the function reference to get the output variable(s) for
* @return the output variables for the given function * @return the output variables for the given function
*/ */
public SDVariable[] getOutputVariablesForFunction(DifferentialFunction function) { public SDVariable[] getOutputVariablesForOp(DifferentialFunction function) {
val inputs = getOutputsForFunction(function); val inputs = getOutputsForOp(function);
if (inputs == null) { if (inputs == null) {
throw new ND4JIllegalStateException("No inputs found for function " + function); throw new ND4JIllegalStateException("No inputs found for function " + function);
} }
@ -774,8 +819,8 @@ public class SameDiff extends SDBaseOps {
* @param function the function reference to get the input variable(s) for * @param function the function reference to get the input variable(s) for
* @return the input variables for the given function * @return the input variables for the given function
*/ */
public SDVariable[] getInputVariablesForFunction(DifferentialFunction function) { public SDVariable[] getInputVariablesForOp(DifferentialFunction function) {
val inputs = getInputsForFunction(function); val inputs = getInputsForOp(function);
if (inputs == null) { if (inputs == null) {
throw new ND4JIllegalStateException("No inputs found for function " + function); throw new ND4JIllegalStateException("No inputs found for function " + function);
} }
@ -792,6 +837,10 @@ public class SameDiff extends SDBaseOps {
} }
/**
* Set the stored {@link INDArray} for a variable. Only works if the variable is of type
* {@link VariableType#CONSTANT}, {@link VariableType#PLACEHOLDER}, or {@link VariableType#VARIABLE}.
*/
public void setArrayForVariable(@NonNull String varName, @NonNull INDArray arr) { public void setArrayForVariable(@NonNull String varName, @NonNull INDArray arr) {
Preconditions.checkState(variables.containsKey(varName), "No variable with name \"%s\" exists", varName); Preconditions.checkState(variables.containsKey(varName), "No variable with name \"%s\" exists", varName);
@ -830,6 +879,9 @@ public class SameDiff extends SDBaseOps {
return variableNameToShape.get(varName); return variableNameToShape.get(varName);
} }
/**
* See {@link #getShapeForVarName(String)}, but returns the shape descriptor.
*/
public LongShapeDescriptor getShapeDescriptorForVarName(String varName) { public LongShapeDescriptor getShapeDescriptorForVarName(String varName) {
if (getVariable(varName).getArr() != null) { if (getVariable(varName).getArr() != null) {
return getVariable(varName).getArr().shapeDescriptor(); return getVariable(varName).getArr().shapeDescriptor();
@ -861,6 +913,9 @@ public class SameDiff extends SDBaseOps {
} }
/**
* Sets the shape descriptor for a variable.
*/
public void putShapeForVarName(String varName, LongShapeDescriptor shape) { public void putShapeForVarName(String varName, LongShapeDescriptor shape) {
val v = getVariable(varName); val v = getVariable(varName);
putShapeForVarName(varName, shape.getShape()); putShapeForVarName(varName, shape.getShape());
@ -1559,19 +1614,6 @@ public class SameDiff extends SDBaseOps {
} }
/**
* Get the differential function (if any) that this variable is the output for
*
* @param variableName Name of the variable
* @return The differential function that this variable is an output of, or null if it is not the output of a function
*/
public DifferentialFunction getVariableOutputFunction(String variableName) {
Preconditions.checkState(variables.containsKey(variableName), "No variable with name \"%s\" found in graph", variableName);
if (variables.get(variableName).getOutputOfOp() == null)
return null;
return ops.get(variables.get(variableName).getOutputOfOp()).getOp();
}
/** /**
* Returns true if this function already has defined arguments * Returns true if this function already has defined arguments
@ -1628,7 +1670,7 @@ public class SameDiff extends SDBaseOps {
* *
* @return Array of differential functions * @return Array of differential functions
*/ */
public DifferentialFunction[] functions() { public DifferentialFunction[] ops() {
List<DifferentialFunction> out = new ArrayList<>(ops.size()); List<DifferentialFunction> out = new ArrayList<>(ops.size());
for (SameDiffOp op : ops.values()) { for (SameDiffOp op : ops.values()) {
out.add(op.getOp()); out.add(op.getOp());
@ -3143,10 +3185,18 @@ public class SameDiff extends SDBaseOps {
placeholders, batch, requiredActivations, activeListeners, at); placeholders, batch, requiredActivations, activeListeners, at);
} }
/**
* See {@link #one(String, DataType, int...)}.
* Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}).
*/
public SDVariable one(String name, int... shape) { public SDVariable one(String name, int... shape) {
return one(name, Nd4j.defaultFloatingPointType(), shape); return one(name, Nd4j.defaultFloatingPointType(), shape);
} }
/**
* See {@link #one(String, DataType, long...)}.
* Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}).
*/
public SDVariable one(String name, long... shape) { public SDVariable one(String name, long... shape) {
return one(name, Nd4j.defaultFloatingPointType(), shape); return one(name, Nd4j.defaultFloatingPointType(), shape);
} }
@ -3174,11 +3224,18 @@ public class SameDiff extends SDBaseOps {
return var(name, new ConstantInitScheme('f', 1.0), dataType, shape); return var(name, new ConstantInitScheme('f', 1.0), dataType, shape);
} }
/**
* See {@link #zero(String, DataType, long...)}.
* Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}).
*/
public SDVariable zero(String name, long... shape) { public SDVariable zero(String name, long... shape) {
return zero(name, Nd4j.defaultFloatingPointType(), shape); return zero(name, Nd4j.defaultFloatingPointType(), shape);
} }
/**
* See {@link #zero(String, DataType, int...)}.
* Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}).
*/
public SDVariable zero(String name, int... shape) { public SDVariable zero(String name, int... shape) {
return zero(name, Nd4j.defaultFloatingPointType(), shape); return zero(name, Nd4j.defaultFloatingPointType(), shape);
} }
@ -3293,6 +3350,18 @@ public class SameDiff extends SDBaseOps {
} }
//TODO only allowing null datatype for TF import (it's fixed in a later step) - don't want this in the public API! //TODO only allowing null datatype for TF import (it's fixed in a later step) - don't want this in the public API!
/**
* Variable initialization with a specified {@link WeightInitScheme}
* This method creates VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details.
*
* @param name the name of the variable
* @param variableType the SameDiff variable type of the variable (e.g. CONSTANT, PLACEHOLDER, etc.)
* @param weightInitScheme the weight initialization scheme
* @param dataType the data type of the variable (float, int, etc)
* @param shape the shape of the array to be created
* @return the created variable
*/
public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme, public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme,
org.nd4j.linalg.api.buffer.DataType dataType, long... shape) { org.nd4j.linalg.api.buffer.DataType dataType, long... shape) {
@ -3932,7 +4001,7 @@ public class SameDiff extends SDBaseOps {
* @param varName the variable name to remove * @param varName the variable name to remove
* @param function the function to remove the argument from * @param function the function to remove the argument from
*/ */
public void removeArgFromFunction(String varName, DifferentialFunction function) { public void removeArgFromOp(String varName, DifferentialFunction function) {
val args = function.args(); val args = function.args();
for (int i = 0; i < args.length; i++) { for (int i = 0; i < args.length; i++) {
@ -4324,7 +4393,7 @@ public class SameDiff extends SDBaseOps {
} }
//Update the internal state: outgoing variables for function //Update the internal state: outgoing variables for function
if (getOutputsForFunction(function) == null) if (getOutputsForOp(function) == null)
addOutgoingFor(ret, function); addOutgoingFor(ret, function);
return ret; return ret;
@ -4357,7 +4426,7 @@ public class SameDiff extends SDBaseOps {
//Update the internal state: outgoing variables for function //Update the internal state: outgoing variables for function
if (getOutputsForFunction(function) == null) if (getOutputsForOp(function) == null)
addOutgoingFor(ret, function); addOutgoingFor(ret, function);
return ret; return ret;
@ -4428,7 +4497,9 @@ public class SameDiff extends SDBaseOps {
.build(); .build();
} }
/**
* Create a new TensorArray.
*/
public TensorArray tensorArray(DataType dataType) { public TensorArray tensorArray(DataType dataType) {
TensorArray ta = new TensorArray(this, dataType); TensorArray ta = new TensorArray(this, dataType);
SDVariable[] outVars = ta.outputVariables(); SDVariable[] outVars = ta.outputVariables();
@ -4439,7 +4510,6 @@ public class SameDiff extends SDBaseOps {
* @param functionName * @param functionName
* @param with * @param with
*/ */
public SDVariable invokeFunctionOn(String functionName, SameDiff with) { public SDVariable invokeFunctionOn(String functionName, SameDiff with) {
SameDiff instance = sameDiffFunctionInstances.get(functionName); SameDiff instance = sameDiffFunctionInstances.get(functionName);
SDVariable ret = instance.invokeGraphOn(with); SDVariable ret = instance.invokeGraphOn(with);
@ -5746,6 +5816,13 @@ public class SameDiff extends SDBaseOps {
return bufferBuilder.dataBuffer(); return bufferBuilder.dataBuffer();
} }
/**
* See {@link #asFlatGraph(long, ExecutorConfiguration, boolean)}.
*
* Uses the default {@link ExecutorConfiguration} with output mode as
* {@link OutputMode#VARIABLE_SPACE}, execution mode as {@link ExecutionMode#SEQUENTIAL},
* with profiling disabled and gather timings enabled.
*/
public FlatGraph asFlatGraph(boolean includeUpdaterState) { public FlatGraph asFlatGraph(boolean includeUpdaterState) {
return FlatGraph.getRootAsFlatGraph(this.asFlatBuffers(includeUpdaterState)); return FlatGraph.getRootAsFlatGraph(this.asFlatBuffers(includeUpdaterState));
} }
@ -5765,6 +5842,10 @@ public class SameDiff extends SDBaseOps {
* This method exports the current SameDiff instance into FlatBuffers format, returning the array ops and * This method exports the current SameDiff instance into FlatBuffers format, returning the array ops and
* all arrays as a ByteBuffer containing the FlatBuffers format data * all arrays as a ByteBuffer containing the FlatBuffers format data
* *
* Uses the default {@link ExecutorConfiguration} with output mode as
* {@link OutputMode#VARIABLE_SPACE}, execution mode as {@link ExecutionMode#SEQUENTIAL},
* with profiling disabled and gather timings enabled.
*
* @param includeUpdaterState If true: include the updater state (state for updaters such as Adam, Nesterov, AdaGrad etc) * @param includeUpdaterState If true: include the updater state (state for updaters such as Adam, Nesterov, AdaGrad etc)
* @return a ByteBuffer holding the exported FlatBuffers representation of the graph * @return a ByteBuffer holding the exported FlatBuffers representation of the graph
*/ */
@ -5870,7 +5951,11 @@ public class SameDiff extends SDBaseOps {
/** /**
* This method converts SameDiff instance to FlatBuffers and saves it to file which can be restored later<br> * This method converts SameDiff instance to FlatBuffers and saves it to file which can be restored later<br>
* This includes the updater state, if applicable * This includes the updater state, if applicable.
*
* Uses the default {@link ExecutorConfiguration} with output mode as
* {@link OutputMode#VARIABLE_SPACE}, execution mode as {@link ExecutionMode#SEQUENTIAL},
* with profiling disabled and gather timings enabled.
* *
* @param file File to save the FlatBuffers serialized graph (including arrays) to * @param file File to save the FlatBuffers serialized graph (including arrays) to
*/ */
@ -5878,6 +5963,13 @@ public class SameDiff extends SDBaseOps {
asFlatFile(file, true); asFlatFile(file, true);
} }
/**
* See {@link #asFlatFile(File, ExecutorConfiguration, boolean)}.
*
* Uses the default {@link ExecutorConfiguration} with output mode as
* {@link OutputMode#VARIABLE_SPACE}, execution mode as {@link ExecutionMode#SEQUENTIAL},
* with profiling disabled and gather timings enabled.
*/
public void asFlatFile(@NonNull File file, boolean withUpdaterState) throws IOException { public void asFlatFile(@NonNull File file, boolean withUpdaterState) throws IOException {
val fb = asFlatBuffers(withUpdaterState); val fb = asFlatBuffers(withUpdaterState);
val offset = fb.position(); val offset = fb.position();
@ -5943,6 +6035,8 @@ public class SameDiff extends SDBaseOps {
* instance from a byte buffers * instance from a byte buffers
* instance. * instance.
* *
* See {@link #fromFlatBuffers(ByteBuffer, boolean)}. Loads updater state (loadUpdaterState is true).
*
* @param bbIn the input byte buffer * @param bbIn the input byte buffer
* @return the created samediff instance * @return the created samediff instance
* @throws IOException * @throws IOException
@ -5951,6 +6045,16 @@ public class SameDiff extends SDBaseOps {
return fromFlatBuffers(bbIn, true); return fromFlatBuffers(bbIn, true);
} }
/**
* Create a {@link SameDiff}
* instance from a byte buffers
* instance.
*
* @param bbIn the input byte buffer
* @param loadUpdaterState If true, load the updater state (Adam etc state). For training, use true. For inference, use false
* @return the created samediff instance
* @throws IOException
*/
public static SameDiff fromFlatBuffers(ByteBuffer bbIn, boolean loadUpdaterState) throws IOException { public static SameDiff fromFlatBuffers(ByteBuffer bbIn, boolean loadUpdaterState) throws IOException {
FlatGraph fg = FlatGraph.getRootAsFlatGraph(bbIn); FlatGraph fg = FlatGraph.getRootAsFlatGraph(bbIn);
@ -6287,7 +6391,7 @@ public class SameDiff extends SDBaseOps {
public String summary() { public String summary() {
Map<String, SDVariable> varMap = variableMap(); Map<String, SDVariable> varMap = variableMap();
DifferentialFunction[] functions = functions(); DifferentialFunction[] functions = ops();
int countVarsWithArrays = 0; int countVarsWithArrays = 0;
@ -6324,7 +6428,7 @@ public class SameDiff extends SDBaseOps {
if (outputOf == null) { if (outputOf == null) {
outputOf = "<none>"; outputOf = "<none>";
} else { } else {
DifferentialFunction d = getFunctionById(outputOf); DifferentialFunction d = getOpById(outputOf);
outputOf = d.getOwnName() + "(" + d.opName() + ")"; outputOf = d.getOwnName() + "(" + d.opName() + ")";
} }
outputOfFn.put(s, outputOf); outputOfFn.put(s, outputOf);
@ -6412,7 +6516,7 @@ public class SameDiff extends SDBaseOps {
for (Map.Entry<String, SameDiff> e : sameDiffFunctionInstances.entrySet()) { for (Map.Entry<String, SameDiff> e : sameDiffFunctionInstances.entrySet()) {
SameDiff sd = e.getValue(); SameDiff sd = e.getValue();
int vars = sd.variableMap().size(); int vars = sd.variableMap().size();
int fns = (sd.functions() == null ? 0 : sd.functions().length); int fns = (sd.ops() == null ? 0 : sd.ops().length);
int defFns = sd.definedFunctionNames().size(); int defFns = sd.definedFunctionNames().size();
sb.append(String.format(format, e.getKey(), String.valueOf(vars), String.valueOf(fns), String.valueOf(defFns))).append("\n"); sb.append(String.format(format, e.getKey(), String.valueOf(vars), String.valueOf(fns), String.valueOf(defFns))).append("\n");
@ -6422,11 +6526,16 @@ public class SameDiff extends SDBaseOps {
return sb.toString(); return sb.toString();
} }
/**
* Calculate data types for the variables in the graph
*/
public Map<String, org.nd4j.linalg.api.buffer.DataType> calculateOutputDataTypes() { public Map<String, org.nd4j.linalg.api.buffer.DataType> calculateOutputDataTypes() {
return calculateOutputDataTypes(false); return calculateOutputDataTypes(false);
} }
/**
* Calculate data types for the variables in the graph
*/
public Map<String, org.nd4j.linalg.api.buffer.DataType> calculateOutputDataTypes(boolean dynamicUpdate) { public Map<String, org.nd4j.linalg.api.buffer.DataType> calculateOutputDataTypes(boolean dynamicUpdate) {
List<String> allVars = new ArrayList<>(variables.keySet()); List<String> allVars = new ArrayList<>(variables.keySet());
DataTypesSession session = new DataTypesSession(this, dynamicUpdate); DataTypesSession session = new DataTypesSession(this, dynamicUpdate);

View File

@ -325,7 +325,7 @@ public abstract class AbstractSession<T, O> {
} }
} else if (sameDiff.getVariableOutputFunction(varToExec.getVariable()) != null) { } else if (sameDiff.getVariableOutputOp(varToExec.getVariable()) != null) {
//Variable is the output of an op -> execute op //Variable is the output of an op -> execute op
String opName = sameDiff.getVariables().get(varToExec.getVariable()).getOutputOfOp(); String opName = sameDiff.getVariables().get(varToExec.getVariable()).getOutputOfOp();
@ -336,7 +336,7 @@ public abstract class AbstractSession<T, O> {
//Post execution: work out what is now available for exec //Post execution: work out what is now available for exec
String[] opOutputVarNames = sameDiff.getFunctionById(opName).outputVariablesNames(); String[] opOutputVarNames = sameDiff.getOpById(opName).outputVariablesNames();
Preconditions.checkState(opOutputValues.length == opOutputVarNames.length, "Unexpected number of outputs from executed op %s:" + Preconditions.checkState(opOutputValues.length == opOutputVarNames.length, "Unexpected number of outputs from executed op %s:" +
" got %s outputs when %s outputs were expected (%s)", parameterizedOp.getClass().getSimpleName(), opOutputValues.length, " got %s outputs when %s outputs were expected (%s)", parameterizedOp.getClass().getSimpleName(), opOutputValues.length,
@ -423,10 +423,10 @@ public abstract class AbstractSession<T, O> {
//Note subgraph initially should include placeholders and constants //Note subgraph initially should include placeholders and constants
while (!processingQueue.isEmpty()) { while (!processingQueue.isEmpty()) {
String varName = processingQueue.remove(); String varName = processingQueue.remove();
String opName = (sameDiff.getVariableOutputFunction(varName) == null ? null : sameDiff.getVariableOutputFunction(varName).getOwnName()); String opName = (sameDiff.getVariableOutputOp(varName) == null ? null : sameDiff.getVariableOutputOp(varName).getOwnName());
if (!subgraph.contains(varName)) { if (!subgraph.contains(varName)) {
String[] opInputs = opName == null ? null : sameDiff.getInputsForFunction(sameDiff.getFunctionById(opName)); String[] opInputs = opName == null ? null : sameDiff.getInputsForOp(sameDiff.getOpById(opName));
List<String> controlDeps = sameDiff.getVariables().get(varName).getControlDeps(); List<String> controlDeps = sameDiff.getVariables().get(varName).getControlDeps();
int numInputs = (opInputs == null ? 0 : opInputs.length); int numInputs = (opInputs == null ? 0 : opInputs.length);
if (controlDeps != null) { if (controlDeps != null) {
@ -457,7 +457,7 @@ public abstract class AbstractSession<T, O> {
if (opName != null) { if (opName != null) {
//To execute op - and hence get this variable: need inputs to that op //To execute op - and hence get this variable: need inputs to that op
String[] inputs = sameDiff.getInputsForFunction(sameDiff.getFunctionById(opName)); String[] inputs = sameDiff.getInputsForOp(sameDiff.getOpById(opName));
for (String s2 : inputs) { for (String s2 : inputs) {
if (!subgraph.contains(s2)) { if (!subgraph.contains(s2)) {
processingQueue.add(s2); processingQueue.add(s2);
@ -501,7 +501,7 @@ public abstract class AbstractSession<T, O> {
if (inputForOps != null) { if (inputForOps != null) {
for (String opName : inputForOps) { for (String opName : inputForOps) {
DifferentialFunction fn = sameDiff.getFunctionById(opName); DifferentialFunction fn = sameDiff.getOpById(opName);
if (fn instanceof Merge) { if (fn instanceof Merge) {
//Merge op: available for execution when *any* of its inputs are available. But only mark it for exec once... //Merge op: available for execution when *any* of its inputs are available. But only mark it for exec once...
List<String> opOutputs = sameDiff.getOps().get(opName).getOutputsOfOp(); List<String> opOutputs = sameDiff.getOps().get(opName).getOutputsOfOp();
@ -888,7 +888,7 @@ public abstract class AbstractSession<T, O> {
//Mark that outVar needs this specific executedVar (i.e., specific frame/iteration) //Mark that outVar needs this specific executedVar (i.e., specific frame/iteration)
//However, in the case of enter nodes, they are available for ALL iterations (used in loop conditions, for example) //However, in the case of enter nodes, they are available for ALL iterations (used in loop conditions, for example)
Variable v = sameDiff.getVariables().get(inputVar.getVariable()); Variable v = sameDiff.getVariables().get(inputVar.getVariable());
boolean isEnter = sameDiff.getVariableOutputFunction(v.getVariable().getVarName()) instanceof Enter; boolean isEnter = sameDiff.getVariableOutputOp(v.getVariable().getVarName()) instanceof Enter;
if(isEnter){ if(isEnter){
VarId iter0 = forVariable; VarId iter0 = forVariable;

View File

@ -59,7 +59,7 @@ public class DataTypesSession extends AbstractSession<DataType, DataTypesSession
@Override @Override
public DataTypeCalc getAndParameterizeOp(String opName, FrameIter frameIter, Set<VarId> inputs, Set<VarId> allIterInputs, Set<String> constAndPhInputs, Map<String, DataType> placeholderValues) { public DataTypeCalc getAndParameterizeOp(String opName, FrameIter frameIter, Set<VarId> inputs, Set<VarId> allIterInputs, Set<String> constAndPhInputs, Map<String, DataType> placeholderValues) {
DifferentialFunction df = sameDiff.getFunctionById(opName); DifferentialFunction df = sameDiff.getOpById(opName);
List<DataType> inputDataTypes = new ArrayList<>(); List<DataType> inputDataTypes = new ArrayList<>();
for(SDVariable v : df.args()){ for(SDVariable v : df.args()){
DataType dt = v.dataType(); DataType dt = v.dataType();

View File

@ -16,7 +16,6 @@
package org.nd4j.autodiff.samediff.internal; package org.nd4j.autodiff.samediff.internal;
import com.google.common.collect.ImmutableMap;
import lombok.NonNull; import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
@ -121,12 +120,12 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
if(listeners != null && listeners.size() > 0){ if(listeners != null && listeners.size() > 0){
SameDiffOp sdOp = sameDiff.getOps().get(op.getOwnName()); SameDiffOp sdOp = sameDiff.getOps().get(op.getOwnName());
ImmutableMap.Builder<String, INDArray> namedOutsBuilder = ImmutableMap.builder(); Map<String, INDArray> namedOutsBuilder = new HashMap<>();
for(int i = 0 ; i < out.length ; i++) for(int i = 0 ; i < out.length ; i++)
namedOutsBuilder.put(sdOp.outputsOfOp.get(i), out[i]); namedOutsBuilder.put(sdOp.outputsOfOp.get(i), out[i]);
Map<String, INDArray> namedOuts = namedOutsBuilder.build(); Map<String, INDArray> namedOuts = Collections.unmodifiableMap(namedOutsBuilder);
for(Listener l : listeners){ for(Listener l : listeners){
if(l.isActive(at.operation())) { if(l.isActive(at.operation())) {
@ -223,7 +222,7 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
//Merge avairable for forward pass when any of its inputs are available. When multiple are available, behaviour //Merge avairable for forward pass when any of its inputs are available. When multiple are available, behaviour
// is undefined // is undefined
Merge m = (Merge) op; Merge m = (Merge) op;
String[] in = sameDiff.getInputsForFunction(op); String[] in = sameDiff.getInputsForOp(op);
for (String s : in) { for (String s : in) {
VarId vid = newVarId(s, outputFrameIter); VarId vid = newVarId(s, outputFrameIter);
if (nodeOutputs.containsKey(vid)) { if (nodeOutputs.containsKey(vid)) {
@ -275,10 +274,10 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
Preconditions.checkState(v != null, "Could not find input %s", inTensorArray.getVarName()); Preconditions.checkState(v != null, "Could not find input %s", inTensorArray.getVarName());
while(sameDiff.getVariableOutputFunction(inTensorArray.getVarName()) instanceof Enter){ while(sameDiff.getVariableOutputOp(inTensorArray.getVarName()) instanceof Enter){
//Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayRead //Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayRead
//TODO also TensorArrayWrite, scatter, etc?? //TODO also TensorArrayWrite, scatter, etc??
inTensorArray = sameDiff.getVariableOutputFunction(inTensorArray.getVarName()).arg(); inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.getVarName()).arg();
v = newVarId(inTensorArray.getVarName(), v.getParentFrame()); v = newVarId(inTensorArray.getVarName(), v.getParentFrame());
} }
@ -300,10 +299,10 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
Preconditions.checkState(tArr != null, "Could not find input %s", inTensorArray.getVarName()); Preconditions.checkState(tArr != null, "Could not find input %s", inTensorArray.getVarName());
while(sameDiff.getVariableOutputFunction(inTensorArray.getVarName()) instanceof Enter){ while(sameDiff.getVariableOutputOp(inTensorArray.getVarName()) instanceof Enter){
//Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayWrite //Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayWrite
//TODO also TensorArrayScatter, etc?? //TODO also TensorArrayScatter, etc??
inTensorArray = sameDiff.getVariableOutputFunction(inTensorArray.getVarName()).arg(); inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.getVarName()).arg();
tArr = newVarId(inTensorArray.getVarName(), tArr.getParentFrame()); tArr = newVarId(inTensorArray.getVarName(), tArr.getParentFrame());
} }
@ -405,7 +404,7 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
//Input 2: The values to scatter //Input 2: The values to scatter
SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array
TensorArray ta = (TensorArray) sameDiff.getVariableOutputFunction(inTensorArray.getVarName()); TensorArray ta = (TensorArray) sameDiff.getVariableOutputOp(inTensorArray.getVarName());
VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false)); VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false));
if(tArr == null && allIterInputs != null){ if(tArr == null && allIterInputs != null){
tArr = lookup(inTensorArray.getVarName(), allIterInputs, false); tArr = lookup(inTensorArray.getVarName(), allIterInputs, false);
@ -526,7 +525,7 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
public DifferentialFunction getAndParameterizeOp(String opName, FrameIter frameIter, Set<VarId> opInputs, Set<VarId> allIterInputs, public DifferentialFunction getAndParameterizeOp(String opName, FrameIter frameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
Set<String> constAndPhInputs, Map<String,INDArray> placeholderValues) { Set<String> constAndPhInputs, Map<String,INDArray> placeholderValues) {
DifferentialFunction df = sameDiff.getFunctionById(opName); DifferentialFunction df = sameDiff.getOpById(opName);
//TODO We should clone these ops - probably - as we don't want them shared between threads/sessions! //TODO We should clone these ops - probably - as we don't want them shared between threads/sessions!
//But let's only clone them *once* and cache in inference session - not on every exec //But let's only clone them *once* and cache in inference session - not on every exec

View File

@ -18,7 +18,6 @@ package org.nd4j.autodiff.samediff.ops;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import lombok.NonNull; import lombok.NonNull;
@ -27,7 +26,6 @@ import org.nd4j.autodiff.samediff.ArgumentInterceptor;
import org.nd4j.autodiff.samediff.NameScope; import org.nd4j.autodiff.samediff.NameScope;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
import org.nd4j.autodiff.samediff.SameDiffLambda; import org.nd4j.autodiff.samediff.SameDiffLambda;
import org.nd4j.autodiff.samediff.SameDiffNoArgSingleLambda; import org.nd4j.autodiff.samediff.SameDiffNoArgSingleLambda;
import org.nd4j.autodiff.samediff.SameDiffSingleLambda; import org.nd4j.autodiff.samediff.SameDiffSingleLambda;
@ -3377,7 +3375,7 @@ public abstract class SDBaseOps {
for(SameDiffOp op : sd().getOpsInScope(ifScope)) { for(SameDiffOp op : sd().getOpsInScope(ifScope)) {
for(String in : op.getInputsToOp()){ for(String in : op.getInputsToOp()){
sd().removeArgFromFunction(in, op.getOp()); sd().removeArgFromOp(in, op.getOp());
} }
sd().getOps().remove(op.getName()); sd().getOps().remove(op.getName());
} }

View File

@ -385,6 +385,29 @@ public class SDCNN extends SDOps {
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
/**
* 3D CNN deconvolution operation with or without optional bias
*
* @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
* @param weights Weights array - shape [kD, kH, kW, oC, iC]
* @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels]
* @param config Configuration
*/
public SDVariable deconv3d(SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) {
return deconv3d(null, input, weights, bias, config);
}
/**
* 3D CNN deconvolution operation with no bias
*
* @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
* @param weights Weights array - shape [kD, kH, kW, oC, iC]
* @param config Configuration
*/
public SDVariable deconv3d(SDVariable input, SDVariable weights, DeConv3DConfig config) {
return deconv3d(input, weights, null, config);
}
/** /**
* Convolution 2d layer batch to space operation on 4d input.<br> * Convolution 2d layer batch to space operation on 4d input.<br>
* Reduces input channels dimension by rearranging data into a larger spatial dimensions<br> * Reduces input channels dimension by rearranging data into a larger spatial dimensions<br>

View File

@ -199,7 +199,7 @@ public class LegacyOpMapper {
case 25: case 25:
return Or.class; return Or.class;
case 26: case 26:
return OldAtan2Op.class; throw new UnsupportedOperationException("OldATan2 (op number " + opNum + ") is no longer supported.");
case 27: case 27:
return LogicalOr.class; return LogicalOr.class;
case 28: case 28:
@ -243,7 +243,7 @@ public class LegacyOpMapper {
case 18: case 18:
return Floor.class; return Floor.class;
case 20: case 20:
return OldReverse.class; throw new UnsupportedOperationException("OldReverse (op number " + opNum + ") is no longer supported.");
default: default:
throw new UnsupportedOperationException("No known transform same op for op number: " + opNum); throw new UnsupportedOperationException("No known transform same op for op number: " + opNum);
} }
@ -581,19 +581,19 @@ public class LegacyOpMapper {
public static Class<?> pairwiseOpClass(int opNum){ public static Class<?> pairwiseOpClass(int opNum){
switch (opNum){ switch (opNum){
case 0: case 0:
return OldAddOp.class; throw new UnsupportedOperationException("OldFModOp (op number " + opNum + ") is no longer supported.");
case 1: case 1:
return CopyOp.class; return CopyOp.class;
case 2: case 2:
return OldDivOp.class; throw new UnsupportedOperationException("OldDivOp (op number " + opNum + ") is no longer supported.");
case 3: case 3:
return OldEqualTo.class; throw new UnsupportedOperationException("OldEqualTo (op number " + opNum + ") is no longer supported.");
case 4: case 4:
return OldGreaterThan.class; throw new UnsupportedOperationException("OldGreaterThan (op number " + opNum + ") is no longer supported.");
case 5: case 5:
return OldLessThan.class; throw new UnsupportedOperationException("OldLessThan (op number " + opNum + ") is no longer supported.");
case 6: case 6:
return OldMulOp.class; throw new UnsupportedOperationException("OldMulOp (op number " + opNum + ") is no longer supported.");
case 7: case 7:
return Pow.class; return Pow.class;
case 8: case 8:
@ -603,15 +603,15 @@ public class LegacyOpMapper {
case 10: case 10:
return Eps.class; return Eps.class;
case 11: case 11:
return OldGreaterThanOrEqual.class; throw new UnsupportedOperationException("OldGreaterThanOrEqual (op number " + opNum + ") is no longer supported.");
case 12: case 12:
return OldLessThanOrEqual.class; throw new UnsupportedOperationException("OldLessThanOrEqual (op number " + opNum + ") is no longer supported.");
case 13: case 13:
return OldMax.class; throw new UnsupportedOperationException("OldMax (op number " + opNum + ") is no longer supported.");
case 14: case 14:
return OldMin.class; throw new UnsupportedOperationException("OldMin (op number " + opNum + ") is no longer supported.");
case 15: case 15:
return OldNotEqualTo.class; throw new UnsupportedOperationException("OldNotEqualTo (op number " + opNum + ") is no longer supported.");
case 16: case 16:
return Set.class; return Set.class;
case 17: case 17:
@ -631,11 +631,11 @@ public class LegacyOpMapper {
case 59: case 59:
return RemainderOp.class; return RemainderOp.class;
case 60: case 60:
return OldFModOp.class; throw new UnsupportedOperationException("OldFModOp (op number " + opNum + ") is no longer supported.");
case 69: case 69:
return OldAtan2Op.class; throw new UnsupportedOperationException("OldATan2 (op number " + opNum + ") is no longer supported.");
case 20: case 20:
return OldFloorDivOp.class; throw new UnsupportedOperationException("OldFloorDivOp (op number " + opNum + ") is no longer supported.");
case 26: case 26:
return RelativeError.class; return RelativeError.class;
case 27: case 27:

View File

@ -78,7 +78,7 @@ public class GraphTransformUtil {
if (oldInputsForOps != null) { if (oldInputsForOps != null) {
List<String> newInputsForOps = new ArrayList<>(); List<String> newInputsForOps = new ArrayList<>();
for (String s : oldInputsForOps) { for (String s : oldInputsForOps) {
DifferentialFunction df = sd.getFunctionById(s); DifferentialFunction df = sd.getOpById(s);
if (!allSubGraphFns.contains(df)) { if (!allSubGraphFns.contains(df)) {
newInputsForOps.add(s); newInputsForOps.add(s);
} }
@ -141,7 +141,7 @@ public class GraphTransformUtil {
// (1) variable is (was) input to op that has been removed - just remove from list // (1) variable is (was) input to op that has been removed - just remove from list
// (2) variable is now connected directly as an output: (A->B->C) becomes (A->C) // (2) variable is now connected directly as an output: (A->B->C) becomes (A->C)
// For the latter case, this // For the latter case, this
DifferentialFunction df = sd.getFunctionById(opName); DifferentialFunction df = sd.getOpById(opName);
if (allSubGraphFns.contains(df)) { if (allSubGraphFns.contains(df)) {
newInputsForOp.remove(opName); newInputsForOp.remove(opName);
} }
@ -178,7 +178,7 @@ public class GraphTransformUtil {
*/ */
public static List<SubGraph> getSubgraphsMatching(SameDiff sd, SubGraphPredicate p) { public static List<SubGraph> getSubgraphsMatching(SameDiff sd, SubGraphPredicate p) {
List<SubGraph> out = new ArrayList<>(); List<SubGraph> out = new ArrayList<>();
for (DifferentialFunction df : sd.functions()) { for (DifferentialFunction df : sd.ops()) {
if (p.matches(sd, df)) { if (p.matches(sd, df)) {
SubGraph sg = p.getSubGraph(sd, df); SubGraph sg = p.getSubGraph(sd, df);
out.add(sg); out.add(sg);

View File

@ -20,7 +20,6 @@ import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.apache.commons.lang3.builder.Diff;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -68,7 +67,7 @@ public class SubGraph {
boolean allInSubgraph = true; boolean allInSubgraph = true;
if(inputsFor != null){ if(inputsFor != null){
for(String opOwnName : inputsFor) { for(String opOwnName : inputsFor) {
if (!inSubgraph(sameDiff.getFunctionById(opOwnName))){ if (!inSubgraph(sameDiff.getOpById(opOwnName))){
allInSubgraph = false; allInSubgraph = false;
break; break;
} }

View File

@ -77,7 +77,7 @@ public class SubGraphPredicate extends OpPredicate {
} }
SDVariable in = inputs[inNum]; SDVariable in = inputs[inNum];
DifferentialFunction df = sameDiff.getVariableOutputFunction(in.getVarName()); DifferentialFunction df = sameDiff.getVariableOutputOp(in.getVarName());
if (df == null || !e.getValue().matches(sameDiff, df)) { if (df == null || !e.getValue().matches(sameDiff, df)) {
return false; return false;
} }
@ -103,7 +103,7 @@ public class SubGraphPredicate extends OpPredicate {
for(Map.Entry<Integer,OpPredicate> entry : opInputSubgraphPredicates.entrySet()){ for(Map.Entry<Integer,OpPredicate> entry : opInputSubgraphPredicates.entrySet()){
OpPredicate p2 = entry.getValue(); OpPredicate p2 = entry.getValue();
SDVariable arg = rootFn.arg(entry.getKey()); SDVariable arg = rootFn.arg(entry.getKey());
DifferentialFunction df = sd.getVariableOutputFunction(arg.getVarName()); DifferentialFunction df = sd.getVariableOutputOp(arg.getVarName());
if(df != null){ if(df != null){
childNodes.add(df); childNodes.add(df);

View File

@ -28,7 +28,6 @@ import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.autodiff.validation.listeners.NonInplaceValidationListener; import org.nd4j.autodiff.validation.listeners.NonInplaceValidationListener;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -106,7 +105,7 @@ public class GradCheckUtil {
} }
Set<String> fnOutputs = new HashSet<>(); Set<String> fnOutputs = new HashSet<>();
for(DifferentialFunction f : sd.functions()){ for(DifferentialFunction f : sd.ops()){
for(SDVariable s : f.outputVariables()){ for(SDVariable s : f.outputVariables()){
fnOutputs.add(s.getVarName()); fnOutputs.add(s.getVarName());
} }
@ -593,7 +592,7 @@ public class GradCheckUtil {
4. Gradient function: should contain all of the existing functions, and more 4. Gradient function: should contain all of the existing functions, and more
*/ */
DifferentialFunction[] dfs = sd.functions(); DifferentialFunction[] dfs = sd.ops();
List<SDVariable> vars = sd.variables(); List<SDVariable> vars = sd.variables();
Set<String> varSetStr = new HashSet<>(); Set<String> varSetStr = new HashSet<>();
@ -661,7 +660,7 @@ public class GradCheckUtil {
//Check that all original functions are present in the gradient function //Check that all original functions are present in the gradient function
for(DifferentialFunction dfOrig : dfs){ for(DifferentialFunction dfOrig : dfs){
Preconditions.checkNotNull(gradFn.getFunctionById(dfOrig.getOwnName()), "DifferentialFunction " + dfOrig.getOwnName() Preconditions.checkNotNull(gradFn.getOpById(dfOrig.getOwnName()), "DifferentialFunction " + dfOrig.getOwnName()
+ " from original SameDiff instance not present in grad fn"); + " from original SameDiff instance not present in grad fn");
} }
} }

View File

@ -79,7 +79,6 @@ import org.nd4j.linalg.api.ops.impl.transforms.gradient.*;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Not; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Not;
import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.*; import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.*;
import org.nd4j.linalg.api.ops.impl.transforms.strict.GELUDerivative; import org.nd4j.linalg.api.ops.impl.transforms.strict.GELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative; import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative;
@ -94,7 +93,6 @@ import org.nd4j.linalg.api.ops.random.impl.Linspace;
import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.function.Function; import org.nd4j.linalg.function.Function;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.tensorflow.framework.OpDef; import org.tensorflow.framework.OpDef;
@ -464,7 +462,7 @@ public class OpValidation {
//i.e., don't double count if a SameDiff instance has multiple copies of the same op type //i.e., don't double count if a SameDiff instance has multiple copies of the same op type
//Collect coverage information for backprop: //Collect coverage information for backprop:
DifferentialFunction[] functions = sd.functions(); DifferentialFunction[] functions = sd.ops();
Set<Class> backpropSeen = new HashSet<>(); Set<Class> backpropSeen = new HashSet<>();
for (DifferentialFunction df : functions) { for (DifferentialFunction df : functions) {
backpropSeen.add(df.getClass()); backpropSeen.add(df.getClass());
@ -481,7 +479,7 @@ public class OpValidation {
if (testCase.fwdTestFns() != null) { if (testCase.fwdTestFns() != null) {
for (String s : testCase.fwdTestFns().keySet()) { for (String s : testCase.fwdTestFns().keySet()) {
//Determine the differential function that this variable is the output of, if any //Determine the differential function that this variable is the output of, if any
DifferentialFunction df = sd.getVariableOutputFunction(s); DifferentialFunction df = sd.getVariableOutputOp(s);
if (df != null) { if (df != null) {
if (seen == null) if (seen == null)
seen = new HashSet<>(); seen = new HashSet<>();

View File

@ -31,7 +31,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet; import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
@ -708,8 +708,8 @@ public class ROC extends BaseEvaluation<ROC> {
itp = isTruePositive; itp = isTruePositive;
ifp = isFalsePositive; ifp = isFalsePositive;
} else { } else {
isTruePositive = Nd4j.getExecutioner().exec(new OldMulOp(predictedClass1, positiveActualClassColumn, itp)); isTruePositive = Nd4j.getExecutioner().exec(new MulOp(predictedClass1, positiveActualClassColumn, itp))[0];
isFalsePositive = Nd4j.getExecutioner().exec(new OldMulOp(predictedClass1, negativeActualClassColumn, ifp)); isFalsePositive = Nd4j.getExecutioner().exec(new MulOp(predictedClass1, negativeActualClassColumn, ifp))[0];
} }
//Counts for this batch: //Counts for this batch:

View File

@ -68,17 +68,6 @@ public class DifferentialFunctionClassHolder {
add("sameDiff"); add("sameDiff");
add("ownName"); add("ownName");
}}; }};
private static final Set<String> classesWithConfig = new LinkedHashSet<String>(){{
add(AvgPooling2D.class.getName());
add(Conv2D.class.getName());
add(Conv3D.class.getName());
add(LocalResponseNormalization.class.getName());
add(MaxPooling2D.class.getName());
add(Pooling2D.class.getName());
add(Pooling3D.class.getName());
add(DepthwiseConv2D.class.getName());
add(DeConv2DTF.class.getName());
}};
//When determining fields/properties, where should we terminate the search? //When determining fields/properties, where should we terminate the search?
//We don't wan to include every single field from every single superclass //We don't wan to include every single field from every single superclass
private static final Set<Class> classesToIgnore = new HashSet<>(Arrays.<Class>asList( private static final Set<Class> classesToIgnore = new HashSet<>(Arrays.<Class>asList(
@ -165,16 +154,37 @@ public class DifferentialFunctionClassHolder {
Map<String, Field> fieldNames = new LinkedHashMap<>(); Map<String, Field> fieldNames = new LinkedHashMap<>();
Class<? extends DifferentialFunction> current = df.getClass(); Class<? extends DifferentialFunction> current = df.getClass();
val fields = new ArrayList<Field>(); val fields = new ArrayList<Field>();
boolean isFirst = true;
while (current.getSuperclass() != null && !classesToIgnore.contains(current.getSuperclass())) { while (current.getSuperclass() != null && !classesToIgnore.contains(current.getSuperclass())) {
if (classesWithConfig.contains(current.getName())) {
val fieldName = "config"; if (df.isConfigProperties() && isFirst) {
val configField = current.getDeclaredField(fieldName); String fieldName = df.configFieldName();
if (configField == null) {
continue; if(fieldName == null)
fieldName = "config";
Field configField = null;
try{
configField = current.getDeclaredField(fieldName);
} catch (NoSuchFieldException e){
Class<?> currentConfig = current.getSuperclass();
// find a config field in superclasses
while(currentConfig.getSuperclass() != null){
try {
configField = currentConfig.getDeclaredField(fieldName);
break;
} catch (NoSuchFieldException e2){
currentConfig = currentConfig.getSuperclass();
}
}
} }
if(configField == null)
continue;
val configFieldClass = configField.getType(); val configFieldClass = configField.getType();
for (val field : configFieldClass.getDeclaredFields()) { for (val field : configFieldClass.getDeclaredFields()) {
@ -206,6 +216,7 @@ public class DifferentialFunctionClassHolder {
// do something with current's fields // do something with current's fields
current = (Class<? extends DifferentialFunction>) current.getSuperclass(); current = (Class<? extends DifferentialFunction>) current.getSuperclass();
isFirst = false;
} }

View File

@ -347,14 +347,6 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace.class, org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace.class,
org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet.class, org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet.class,
org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps.class, org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps.class,
org.nd4j.linalg.api.ops.impl.transforms.comparison.OldEqualTo.class,
org.nd4j.linalg.api.ops.impl.transforms.comparison.OldGreaterThan.class,
org.nd4j.linalg.api.ops.impl.transforms.comparison.OldGreaterThanOrEqual.class,
org.nd4j.linalg.api.ops.impl.transforms.comparison.OldLessThan.class,
org.nd4j.linalg.api.ops.impl.transforms.comparison.OldLessThanOrEqual.class,
org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax.class,
org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin.class,
org.nd4j.linalg.api.ops.impl.transforms.comparison.OldNotEqualTo.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2.class, org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.Assign.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Assign.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace.class, org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace.class,
@ -453,15 +445,6 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAddOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAtan2Op.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldDivOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldFModOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldFloorDivOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldRDivOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldRSubOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldSubOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.PowPairwise.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.PowPairwise.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp.class,
@ -493,8 +476,6 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.same.Max.class, org.nd4j.linalg.api.ops.impl.transforms.same.Max.class,
org.nd4j.linalg.api.ops.impl.transforms.same.Min.class, org.nd4j.linalg.api.ops.impl.transforms.same.Min.class,
org.nd4j.linalg.api.ops.impl.transforms.same.Negative.class, org.nd4j.linalg.api.ops.impl.transforms.same.Negative.class,
org.nd4j.linalg.api.ops.impl.transforms.same.OldIdentity.class,
org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse.class,
org.nd4j.linalg.api.ops.impl.transforms.same.OneMinus.class, org.nd4j.linalg.api.ops.impl.transforms.same.OneMinus.class,
org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal.class, org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal.class,
org.nd4j.linalg.api.ops.impl.transforms.same.Round.class, org.nd4j.linalg.api.ops.impl.transforms.same.Round.class,

View File

@ -361,7 +361,7 @@ public abstract class BaseGraphMapper<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE, TENSOR_T
} }
protected void initOutputVariables(SameDiff sd, DifferentialFunction df) { protected void initOutputVariables(SameDiff sd, DifferentialFunction df) {
String[] outNames = sd.getOutputsForFunction(df); String[] outNames = sd.getOutputsForOp(df);
SDVariable[] outVars; SDVariable[] outVars;
if (outNames == null) { if (outNames == null) {
outVars = sd.generateOutputVariableForOp(df, df.getOwnName() != null ? df.getOwnName() : df.opName(), true); outVars = sd.generateOutputVariableForOp(df, df.getOwnName() != null ? df.getOwnName() : df.opName(), true);

View File

@ -409,7 +409,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
newInstance.setSameDiff(importState.getSameDiff()); newInstance.setSameDiff(importState.getSameDiff());
newInstance.initFromOnnx(tfNode,diff,getAttrMap(tfNode),importState.getGraph()); newInstance.initFromOnnx(tfNode,diff,getAttrMap(tfNode),importState.getGraph());
importState.getSameDiff().putFunctionForId(newInstance.getOwnName(),newInstance); importState.getSameDiff().putOpForId(newInstance.getOwnName(),newInstance);
//ensure we can track node name to function instance later. //ensure we can track node name to function instance later.
diff.setBaseNameForFunctionInstanceId(tfNode.getName(),newInstance); diff.setBaseNameForFunctionInstanceId(tfNode.getName(),newInstance);
//diff.addVarNameForImport(tfNode.getName()); //diff.addVarNameForImport(tfNode.getName());

View File

@ -16,14 +16,11 @@
package org.nd4j.imports.graphmapper.tf; package org.nd4j.imports.graphmapper.tf;
import com.github.os72.protobuf351.Descriptors;
import com.github.os72.protobuf351.Message; import com.github.os72.protobuf351.Message;
import com.google.common.primitives.Floats; import com.google.common.primitives.Floats;
import com.google.common.primitives.Ints; import com.google.common.primitives.Ints;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.bytedeco.javacpp.indexer.Bfloat16Indexer;
import org.bytedeco.javacpp.indexer.HalfIndexer;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -41,19 +38,14 @@ import org.nd4j.imports.graphmapper.OpImportFilter;
import org.nd4j.imports.graphmapper.OpImportOverride; import org.nd4j.imports.graphmapper.OpImportOverride;
import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMapper; import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMapper;
import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMappers; import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMappers;
import org.nd4j.linalg.api.buffer.*;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.controlflow.IfImportState; import org.nd4j.linalg.api.ops.impl.controlflow.IfImportState;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.*; import org.tensorflow.framework.*;
import org.tensorflow.framework.DataType; import org.tensorflow.framework.DataType;
import java.io.*; import java.io.*;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*; import java.util.*;
/** /**
@ -661,7 +653,7 @@ public class TFGraphMapper extends BaseGraphMapper<GraphDef,NodeDef,AttrValue,No
newInstance.initFromTensorFlow(tfNode, diff, getAttrMap(tfNode), importState.getGraph()); newInstance.initFromTensorFlow(tfNode, diff, getAttrMap(tfNode), importState.getGraph());
mapProperties(newInstance, tfNode, importState.getGraph(), importState.getSameDiff(), newInstance.mappingsForFunction()); mapProperties(newInstance, tfNode, importState.getGraph(), importState.getSameDiff(), newInstance.mappingsForFunction());
importState.getSameDiff().putFunctionForId(newInstance.getOwnName(), newInstance); importState.getSameDiff().putOpForId(newInstance.getOwnName(), newInstance);
//ensure we can track node name to function instance later. //ensure we can track node name to function instance later.
diff.setBaseNameForFunctionInstanceId(tfNode.getName(), newInstance); diff.setBaseNameForFunctionInstanceId(tfNode.getName(), newInstance);
} catch (Exception e) { } catch (Exception e) {

View File

@ -61,6 +61,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo; import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan; import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan; import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan;
import org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.same.Negative; import org.nd4j.linalg.api.ops.impl.transforms.same.Negative;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.*; import org.nd4j.linalg.api.ops.impl.transforms.comparison.*;
@ -1638,7 +1639,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
public INDArray lt(INDArray other) { public INDArray lt(INDArray other) {
validateNumericalArray("less than (lt)", false); validateNumericalArray("less than (lt)", false);
if (Shape.shapeEquals(this.shape(), other.shape())) { if (Shape.shapeEquals(this.shape(), other.shape())) {
return Nd4j.getExecutioner().exec(new OldLessThan(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()))); return Nd4j.getExecutioner().exec(new LessThan(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0];
} else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
return Nd4j.exec(new LessThan(new INDArray[]{this, other}, new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0]; return Nd4j.exec(new LessThan(new INDArray[]{this, other}, new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0];
} else } else
@ -1655,13 +1656,13 @@ public abstract class BaseNDArray implements INDArray, Iterable {
@Override @Override
public INDArray neq(INDArray other) { public INDArray neq(INDArray other) {
Preconditions.checkState(!isEmpty(), "Cannot perform operation neq (not equal) on empty array"); Preconditions.checkState(!isEmpty(), "Cannot perform operation neq (not equal) on empty array");
return Nd4j.getExecutioner().exec(new OldNotEqualTo(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()))); return Nd4j.getExecutioner().exec(new NotEqualTo(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0];
} }
@Override @Override
public INDArray eq(INDArray other) { public INDArray eq(INDArray other) {
if (Shape.shapeEquals(this.shape(), other.shape())) { if (Shape.shapeEquals(this.shape(), other.shape())) {
return Nd4j.getExecutioner().exec(new OldEqualTo(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()))); return Nd4j.getExecutioner().exec(new EqualTo(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0];
} else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
return Nd4j.exec(new EqualTo(new INDArray[]{this, other}, new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0]; return Nd4j.exec(new EqualTo(new INDArray[]{this, other}, new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0];
} else } else
@ -1672,7 +1673,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
public INDArray gt(INDArray other) { public INDArray gt(INDArray other) {
validateNumericalArray("greater than (gt)", false); validateNumericalArray("greater than (gt)", false);
if (Shape.shapeEquals(this.shape(), other.shape())) { if (Shape.shapeEquals(this.shape(), other.shape())) {
return Nd4j.getExecutioner().exec(new OldGreaterThan(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()))); return Nd4j.getExecutioner().exec(new GreaterThan(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0];
} else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
return Nd4j.exec(new GreaterThan(new INDArray[]{this, other}, new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0]; return Nd4j.exec(new GreaterThan(new INDArray[]{this, other}, new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0];
} else } else
@ -5989,7 +5990,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return result; return result;
} else { } else {
OldFModOp op = new OldFModOp(this, denominator, result); FModOp op = new FModOp(this, denominator, result);
Nd4j.getExecutioner().exec(op); Nd4j.getExecutioner().exec(op);
return result; return result;
} }
@ -6011,7 +6012,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
@Override @Override
public INDArray fmodi(INDArray denominator) { public INDArray fmodi(INDArray denominator) {
validateNumericalArray("fmodi", false); validateNumericalArray("fmodi", false);
OldFModOp op = new OldFModOp(this, denominator, this); FModOp op = new FModOp(this, denominator, this);
Nd4j.getExecutioner().exec(op); Nd4j.getExecutioner().exec(op);
return this; return this;
} }

View File

@ -274,7 +274,7 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
@Override @Override
public SDVariable[] outputVariables(String baseName) { public SDVariable[] outputVariables(String baseName) {
if(zVertexId == null) { if(zVertexId == null) {
val outputNames = sameDiff.getOutputsForFunction(this); val outputNames = sameDiff.getOutputsForOp(this);
//no need to dynamically create if already exists //no need to dynamically create if already exists
if(outputNames != null) { if(outputNames != null) {
zVertexId = sameDiff.getVariable(outputNames[0]).getVarName(); zVertexId = sameDiff.getVariable(outputNames[0]).getVarName();
@ -293,13 +293,13 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
sameDiff.setArrayForVariable(newVars[0].getVarName(),inputArr); sameDiff.setArrayForVariable(newVars[0].getVarName(),inputArr);
z = inputArr; z = inputArr;
if(sameDiff.getOutputsForFunction(this) == null) if(sameDiff.getOutputsForOp(this) == null)
sameDiff.addOutgoingFor(newVars,this); sameDiff.addOutgoingFor(newVars,this);
return newVars; return newVars;
} }
SDVariable[] newVars = sameDiff.generateOutputVariableForOp(this, baseName, false); SDVariable[] newVars = sameDiff.generateOutputVariableForOp(this, baseName, false);
if (sameDiff.getOutputsForFunction(this) == null) if (sameDiff.getOutputsForOp(this) == null)
sameDiff.addOutgoingFor(newVars, this); sameDiff.addOutgoingFor(newVars, this);
return newVars; return newVars;
} }

View File

@ -25,11 +25,9 @@ import onnx.OnnxProto3;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.ArrayUtil;
@ -208,7 +206,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
@Override @Override
public SDVariable[] outputVariables(String baseName) { public SDVariable[] outputVariables(String baseName) {
if (this.outputVariables == null) { if (this.outputVariables == null) {
val outputNames = sameDiff.getOutputsForFunction(this); val outputNames = sameDiff.getOutputsForOp(this);
//no need to dynamically create if already exists //no need to dynamically create if already exists
if (outputNames != null) { if (outputNames != null) {
outputVariables = new SDVariable[outputNames.length]; outputVariables = new SDVariable[outputNames.length];
@ -233,7 +231,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
} }
outputVariables = newVars; outputVariables = newVars;
if (sameDiff.getOutputsForFunction(this) == null) if (sameDiff.getOutputsForOp(this) == null)
sameDiff.addOutgoingFor(outputVariables, this); sameDiff.addOutgoingFor(outputVariables, this);
return newVars; return newVars;
} }
@ -524,7 +522,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
throw new ND4JIllegalStateException("Op [" + opName() + "] failure for [" + this.getOwnName() + "]: Number of inputs is invalid for execution. " throw new ND4JIllegalStateException("Op [" + opName() + "] failure for [" + this.getOwnName() + "]: Number of inputs is invalid for execution. "
+ numInputArguments() + " were provided but " + descriptor.getNumInputs() + " are required for execution"); + numInputArguments() + " were provided but " + descriptor.getNumInputs() + " are required for execution");
} else { } else {
String[] inputNames = sameDiff.getInputsForFunction(this); String[] inputNames = sameDiff.getInputsForOp(this);
String[] arrayShapes = new String[inputNames.length]; String[] arrayShapes = new String[inputNames.length];
for( int i=0; i<inputNames.length; i++ ){ for( int i=0; i<inputNames.length; i++ ){
INDArray arr = sameDiff.getVariable(inputNames[i]).getArr(); INDArray arr = sameDiff.getVariable(inputNames[i]).getArr();

View File

@ -107,7 +107,7 @@ public class If extends DifferentialFunction implements CustomOp {
SameDiffFunctionDefinition falseBody) { SameDiffFunctionDefinition falseBody) {
this.sameDiff = parent; this.sameDiff = parent;
parent.putFunctionForId(getOwnName(),this); parent.putOpForId(getOwnName(),this);
this.inputVars = inputVars; this.inputVars = inputVars;
this.predicate = predicate; this.predicate = predicate;

View File

@ -136,7 +136,7 @@ public class While extends DifferentialFunction implements CustomOp {
this.trueBody = trueBody; this.trueBody = trueBody;
this.blockName = blockName; this.blockName = blockName;
this.dummyResult = parent.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme('f'), DataType.FLOAT, 1); this.dummyResult = parent.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme('f'), DataType.FLOAT, 1);
parent.putFunctionForId(getOwnName(),this); parent.putOpForId(getOwnName(),this);
parent.addArgsFor(inputVars,this); parent.addArgsFor(inputVars,this);
parent.addOutgoingFor(new SDVariable[]{dummyResult},this); parent.addOutgoingFor(new SDVariable[]{dummyResult},this);
@ -457,7 +457,7 @@ public class While extends DifferentialFunction implements CustomOp {
//the output of the condition should always be a singular scalar //the output of the condition should always be a singular scalar
//this is a safe assumption //this is a safe assumption
val conditionVars = scopeCondition.functions(); val conditionVars = scopeCondition.ops();
if(conditionVars.length < 1) { if(conditionVars.length < 1) {
throw new ND4JIllegalArgumentException("No functions found!"); throw new ND4JIllegalArgumentException("No functions found!");
} }

View File

@ -22,7 +22,6 @@ import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.OnnxProto3;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -40,7 +39,6 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
@ -71,7 +69,7 @@ public class Conv1D extends DynamicCustomOp {
this.config = config; this.config = config;
Preconditions.checkState(config.getS() >= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP()); Preconditions.checkState(config.getS() >= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP());
addArgs(); addArgs();
sameDiff.putFunctionForId(this.getOwnName(), this); sameDiff.putOpForId(this.getOwnName(), this);
sameDiff.addArgsFor(inputFunctions, this); sameDiff.addArgsFor(inputFunctions, this);
} }
@ -113,12 +111,6 @@ public class Conv1D extends DynamicCustomOp {
return config.toProperties(); return config.toProperties();
} }
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
addArgs();
}
@Override @Override
public boolean isConfigProperties() { public boolean isConfigProperties() {
return true; return true;
@ -129,107 +121,6 @@ public class Conv1D extends DynamicCustomOp {
return "config"; return "config";
} }
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
addArgs();
}
@Override
public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
Map<String, Map<String, AttributeAdapter>> ret = new HashMap<>();
Map<String, AttributeAdapter> tfMappings = new LinkedHashMap<>();
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
tfMappings.put("kH", new ConditionalFieldValueNDArrayShapeAdapter("NHW", 2, 0, fields.get("dataFormat")));
tfMappings.put("kW", new ConditionalFieldValueNDArrayShapeAdapter("NHW", 3, 1, fields.get("dataFormat")));
tfMappings.put("sH", new ConditionalFieldValueIntIndexArrayAdapter("NHW", 2, 1, fields.get("dataFormat")));
tfMappings.put("sW", new ConditionalFieldValueIntIndexArrayAdapter("NHW", 3, 2, fields.get("dataFormat")));
tfMappings.put("isSameMode", new StringEqualsAdapter("SAME"));
tfMappings.put("isNHWC", new StringEqualsAdapter("NHWC"));
Map<String, AttributeAdapter> onnxMappings = new HashMap<>();
onnxMappings.put("kH", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
onnxMappings.put("kW", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
onnxMappings.put("dH", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
onnxMappings.put("dW", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
onnxMappings.put("sH", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
onnxMappings.put("sW", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
onnxMappings.put("isSameMode", new StringEqualsAdapter("SAME"));
onnxMappings.put("isNHWC", new StringEqualsAdapter("NHC"));
ret.put(tensorflowName(), tfMappings);
ret.put(onnxName(), onnxMappings);
return ret;
}
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
Map<String, PropertyMapping> map = new HashMap<>();
val strideMapping = PropertyMapping.builder()
.tfAttrName("strides")
.onnxAttrName("strides")
.propertyNames(new String[]{"s"})
.build();
val kernelMapping = PropertyMapping.builder()
.propertyNames(new String[]{"k"})
.tfInputPosition(1)
.shapePosition(0)
.onnxAttrName("kernel_shape")
.build();
val paddingMapping = PropertyMapping.builder()
.onnxAttrName("padding")
.propertyNames(new String[]{"p"})
.build();
val dataFormat = PropertyMapping.builder()
.onnxAttrName("data_format")
.tfAttrName("data_format")
.propertyNames(new String[]{"dataFormat"})
.build();
val nhwc = PropertyMapping.builder()
.onnxAttrName("data_format")
.tfAttrName("data_format")
.propertyNames(new String[]{"isNHWC"})
.build();
val sameMode = PropertyMapping.builder()
.onnxAttrName("auto_pad")
.propertyNames(new String[]{"isSameMode"})
.tfAttrName("padding")
.build();
map.put("s", strideMapping);
map.put("k", kernelMapping);
map.put("p", paddingMapping);
map.put("isSameMode", sameMode);
map.put("dataFormat", dataFormat);
map.put("isNHWC", nhwc);
try {
ret.put(onnxName(), map);
} catch (NoOpNameFoundException e) {
//ignore
}
try {
ret.put(tensorflowName(), map);
} catch (NoOpNameFoundException e) {
//ignore
}
return ret;
}
@Override @Override
public String opName() { public String opName() {
return "conv1d"; return "conv1d";
@ -241,16 +132,6 @@ public class Conv1D extends DynamicCustomOp {
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
} }
@Override
public String tensorflowName() {
return "Conv1D";
}
@Override
public String[] tensorflowNames() {
return new String[]{"Conv1D"};
}
@Override @Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length; int n = args().length;

View File

@ -70,7 +70,7 @@ public class Conv2D extends DynamicCustomOp {
config.getSH(), config.getPH(), config.getDW()); config.getSH(), config.getPH(), config.getDW());
addArgs(); addArgs();
if(sameDiff != null) { if(sameDiff != null) {
sameDiff.putFunctionForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point
sameDiff.addArgsFor(inputFunctions, this); sameDiff.addArgsFor(inputFunctions, this);
} }
} }

View File

@ -68,7 +68,7 @@ public class DeConv2D extends DynamicCustomOp {
} }
addArgs(); addArgs();
sameDiff.putFunctionForId(this.getOwnName(), this); sameDiff.putOpForId(this.getOwnName(), this);
sameDiff.addArgsFor(inputs, this); sameDiff.addArgsFor(inputs, this);
} }

View File

@ -21,7 +21,6 @@ import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -71,7 +70,7 @@ public class DeConv2DTF extends DynamicCustomOp {
} }
addArgs(); addArgs();
sameDiff.putFunctionForId(this.getOwnName(), this); sameDiff.putOpForId(this.getOwnName(), this);
sameDiff.addArgsFor(inputs, this); sameDiff.addArgsFor(inputs, this);
} }

View File

@ -62,7 +62,7 @@ public class DepthwiseConv2D extends DynamicCustomOp {
this.sameDiff = sameDiff; this.sameDiff = sameDiff;
this.config = config; this.config = config;
addArgs(); addArgs();
sameDiff.putFunctionForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point
sameDiff.addArgsFor(inputFunctions, this); sameDiff.addArgsFor(inputFunctions, this);
} }

View File

@ -76,6 +76,16 @@ public class LocalResponseNormalization extends DynamicCustomOp {
addIArgument(config.getDepth()); addIArgument(config.getDepth());
} }
@Override
public boolean isConfigProperties() {
return true;
}
@Override
public String configFieldName(){
return "config";
}
@Override @Override
public String opName() { public String opName() {
return "lrn"; return "lrn";

View File

@ -65,7 +65,7 @@ public class TensorMmul extends DynamicCustomOp {
this.sameDiff = sameDiff; this.sameDiff = sameDiff;
this.mMulTranspose = mMulTranspose; this.mMulTranspose = mMulTranspose;
this.axes = dimensions; this.axes = dimensions;
if(!addedEdges && sameDiff.getOutputsForFunction(this) == null) { if(!addedEdges && sameDiff.getOutputsForOp(this) == null) {
addedEdges = true; addedEdges = true;
} }

View File

@ -151,7 +151,7 @@ public class Concat extends DynamicCustomOp {
removeInputArgument(inputArgs[inputArguments().length - 1]); removeInputArgument(inputArgs[inputArguments().length - 1]);
} }
sameDiff.removeArgFromFunction(input,this); sameDiff.removeArgFromOp(input,this);
} }
@Override @Override

View File

@ -72,7 +72,7 @@ public class TensorArrayConcat extends BaseTensorOp {
public List<DataType> calculateOutputDataTypes(java.util.List<org.nd4j.linalg.api.buffer.DataType> inputDataType){ public List<DataType> calculateOutputDataTypes(java.util.List<org.nd4j.linalg.api.buffer.DataType> inputDataType){
//Same output type as the TensorArray - which is defined by input 0 //Same output type as the TensorArray - which is defined by input 0
SDVariable tArr = arg(0); SDVariable tArr = arg(0);
TensorArray t3 = (TensorArray) sameDiff.getVariableOutputFunction(tArr.getVarName()); TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.getVarName());
org.nd4j.linalg.api.buffer.DataType dt = t3.getTensorArrayDataType(); org.nd4j.linalg.api.buffer.DataType dt = t3.getTensorArrayDataType();
return Collections.singletonList(dt); return Collections.singletonList(dt);
} }

View File

@ -72,7 +72,7 @@ public class TensorArrayGather extends BaseTensorOp {
public List<DataType> calculateOutputDataTypes(java.util.List<org.nd4j.linalg.api.buffer.DataType> inputDataType){ public List<DataType> calculateOutputDataTypes(java.util.List<org.nd4j.linalg.api.buffer.DataType> inputDataType){
//Same output type as the TensorArray - which is defined by input 0 //Same output type as the TensorArray - which is defined by input 0
SDVariable tArr = arg(0); SDVariable tArr = arg(0);
TensorArray t3 = (TensorArray) sameDiff.getVariableOutputFunction(tArr.getVarName()); TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.getVarName());
org.nd4j.linalg.api.buffer.DataType dt = t3.getTensorArrayDataType(); org.nd4j.linalg.api.buffer.DataType dt = t3.getTensorArrayDataType();
return Collections.singletonList(dt); return Collections.singletonList(dt);
} }

View File

@ -72,7 +72,7 @@ public class TensorArrayRead extends BaseTensorOp {
dt = importDataType; dt = importDataType;
} else { } else {
SDVariable tArr = arg(0); SDVariable tArr = arg(0);
DifferentialFunction op = sameDiff.getVariableOutputFunction(tArr.getVarName()); DifferentialFunction op = sameDiff.getVariableOutputOp(tArr.getVarName());
TensorArray t3 = (TensorArray) op; TensorArray t3 = (TensorArray) op;
dt = t3.getTensorArrayDataType(); dt = t3.getTensorArrayDataType();
} }

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms; package org.nd4j.linalg.api.ops.impl.transforms;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDIndex; import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -55,6 +56,14 @@ public class Pad extends DynamicCustomOp {
addTArgument(padValue); addTArgument(padValue);
} }
public Pad(@NonNull INDArray in, @NonNull INDArray padding, INDArray out, @NonNull Mode mode, double padValue){
super(null, new INDArray[]{in, padding}, out == null ? null : new INDArray[]{out});
Preconditions.checkState(padding.dataType().isIntType(), "Padding array must be an integer datatype, got %s", padding.dataType());
this.mode = mode;
addIArgument(mode.ordinal());
addTArgument(padValue);
}
@Override @Override
public String opName(){ public String opName(){
return "pad"; return "pad";

View File

@ -1,95 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.comparison;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformBoolOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import java.util.Arrays;
import java.util.List;
/**
* Bit mask over the ndarrays as to whether
* the components are equal or not
*
* @author Adam Gibson
*/
public class OldEqualTo extends BaseTransformBoolOp {
public OldEqualTo(SameDiff sameDiff) {
super(sameDiff);
}
public OldEqualTo(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, Object[] extraArgs) {
super(sameDiff, i_v1, i_v2, extraArgs);
}
public OldEqualTo(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
}
public OldEqualTo(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
}
public OldEqualTo(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
}
public OldEqualTo() {}
public OldEqualTo(INDArray x, INDArray y, INDArray z) {
super(x, y, z);
}
public OldEqualTo(INDArray x, INDArray y) {
super(x, y, null);
}
@Override
public int opNum() {
return 0;
}
@Override
public String opName() {
return "oldeq";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No Tensorflow op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
//Equals op: 2 inputs, not continuously differentiable but 0s almost everywhere
return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1]));
}
}

View File

@ -1,91 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.comparison;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformBoolOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import java.util.Arrays;
import java.util.List;
/**
* Bit mask over the ndarrays as to whether
* the components are greater than or not
*
* @author Adam Gibson
*/
public class OldGreaterThan extends BaseTransformBoolOp {
public OldGreaterThan(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
}
public OldGreaterThan(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
}
public OldGreaterThan(SameDiff sameDiff) {
super(sameDiff);
}
public OldGreaterThan(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
}
public OldGreaterThan() {}
public OldGreaterThan(INDArray x, INDArray z) {
super(x, z);
}
public OldGreaterThan(INDArray x, INDArray y, INDArray z) {
super(x, y, z);
}
public OldGreaterThan(INDArray x) {
super(x);
}
@Override
public int opNum() {
return 1;
}
@Override
public String opName() {
return "oldgt";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow name found");
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return Arrays.asList(outputVariables()[0]);
}
}

View File

@ -1,85 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.comparison;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformBoolOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import java.util.List;
/**
* Bit mask over the ndarrays as to whether
* the components are greater than or equal or not
*
* @author Adam Gibson
*/
public class OldGreaterThanOrEqual extends BaseTransformBoolOp {
public OldGreaterThanOrEqual(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
}
public OldGreaterThanOrEqual(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
}
public OldGreaterThanOrEqual(SameDiff sameDiff) {
super(sameDiff);
}
public OldGreaterThanOrEqual(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
}
public OldGreaterThanOrEqual() {}
public OldGreaterThanOrEqual(INDArray x, INDArray y, INDArray z) {
super(x, y, z);
}
public OldGreaterThanOrEqual(INDArray x) {
super(x);
}
@Override
public int opNum() {
return 4;
}
@Override
public String opName() {
return "oldgte";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return null;
}
}

View File

@ -1,91 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.comparison;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformBoolOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import java.util.Arrays;
import java.util.List;
/**
* Bit mask over the ndarrays as to whether
* the components are less than or not
*
* @author Adam Gibson
*/
public class OldLessThan extends BaseTransformBoolOp {
public OldLessThan(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
}
public OldLessThan(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
}
public OldLessThan(SameDiff sameDiff) {
super(sameDiff);
}
public OldLessThan(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
}
public OldLessThan() {}
public OldLessThan(INDArray x, INDArray y, INDArray z) {
super(x, y, z);
}
public OldLessThan(INDArray x) {
super(x);
}
public OldLessThan(INDArray x, INDArray z) {
super(x, z);
}
@Override
public int opNum() {
return 2;
}
@Override
public String opName() {
return "oldlt";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tf opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return Arrays.asList(outputVariables()[0]);
}
}

View File

@ -1,104 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.comparison;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformBoolOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import java.util.Arrays;
import java.util.List;
/**
* Bit mask over the ndarrays as to whether
* the components are less than or equal or not
*
* @author Adam Gibson
*/
public class OldLessThanOrEqual extends BaseTransformBoolOp {
public OldLessThanOrEqual(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
}
public OldLessThanOrEqual(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
}
public OldLessThanOrEqual(SameDiff sameDiff) {
super(sameDiff);
}
public OldLessThanOrEqual(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, Object[] extraArgs) {
super(sameDiff, i_v1, i_v2, extraArgs);
}
public OldLessThanOrEqual(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
}
public OldLessThanOrEqual(SameDiff sameDiff, SDVariable i_v, long[] shape, boolean inPlace, Object[] extraArgs) {
super(sameDiff, i_v, shape, inPlace, extraArgs);
}
public OldLessThanOrEqual(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs) {
super(sameDiff, i_v, extraArgs);
}
public OldLessThanOrEqual() {}
public OldLessThanOrEqual(INDArray x, INDArray z) {
super(x, z);
}
public OldLessThanOrEqual(INDArray x, INDArray y, INDArray z) {
super(x, y, z);
}
public OldLessThanOrEqual(INDArray x) {
super(x);
}
@Override
public int opNum() {
return 5;
}
@Override
public String opName() {
return "oldlte";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return Arrays.asList(outputVariables()[0]);
}
}

View File

@ -1,88 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.comparison;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformSameOp;
import java.util.List;
/**
* Max function
*
* @author Adam Gibson
*/
public class OldMax extends BaseTransformSameOp {
public OldMax(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
}
public OldMax(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
}
public OldMax(SameDiff sameDiff) {
super(sameDiff);
}
public OldMax(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
}
public OldMax() {}
public OldMax(INDArray x, INDArray y, INDArray z) {
super(x, y, z);
}
public OldMax(INDArray x) {
super(x);
}
public OldMax(INDArray ndArray, INDArray dup) {
super(ndArray, dup);
}
@Override
public int opNum() {
return 7;
}
@Override
public String opName() {
return "old_max_transform";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("This is not meant to be mapped, use Max instead");
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("This is not meant to be mapped, use Max instead");
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return null;
}
}

View File

@ -1,88 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.comparison;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformSameOp;
import java.util.List;
/**
* Min function
*
* @author Adam Gibson
*/
public class OldMin extends BaseTransformSameOp {
public OldMin(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
}
public OldMin(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
}
public OldMin(SameDiff sameDiff) {
super(sameDiff);
}
public OldMin(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
}
public OldMin() {}
public OldMin(INDArray x, INDArray y, INDArray z) {
super(x, y, z);
}
public OldMin(INDArray x) {
super(x);
}
public OldMin(INDArray x, INDArray z) {
super(x, z);
}
@Override
public int opNum() {
return 8;
}
@Override
public String opName() {
return "old_min_transform";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("This is not meant to be mapped, use Max instead");
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("This is not meant to be mapped, use Max instead");
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return null;
}
}

View File

@ -1,87 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.comparison;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformBoolOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import java.util.Arrays;
import java.util.List;
/**
* Not equal to function:
* Bit mask over whether 2 elements are not equal or not
*
* @author Adam Gibson
*/
public class OldNotEqualTo extends BaseTransformBoolOp {
public OldNotEqualTo(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
}
public OldNotEqualTo(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
}
public OldNotEqualTo() {
}
public OldNotEqualTo(INDArray x) {
super(x);
}
public OldNotEqualTo(INDArray x, INDArray y, INDArray z) {
super(x, y, z);
}
public OldNotEqualTo(INDArray x, INDArray z) {
super(x, z);
}
@Override
public int opNum() {
return 6;
}
@Override
public String opName() {
return "old_neq";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No op name found");
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
return Arrays.asList(f().neg(i_v.get(0)));
}
}

View File

@ -22,11 +22,13 @@ import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import org.nd4j.linalg.ops.transforms.Transforms;
/** /**
* Arc Tangent elementwise function * Arc Tangent elementwise function
@ -39,6 +41,15 @@ public class ATan2 extends BaseDynamicTransformOp {
super(sameDiff, new SDVariable[] {y, x} ,false); super(sameDiff, new SDVariable[] {y, x} ,false);
} }
/**
* Note that the order of x and y match {@link java.lang.Math#atan2(double, double)},
* and are reversed when compared to OldATan2.
* See {@link Transforms#atan2(org.nd4j.linalg.api.ndarray.INDArray, org.nd4j.linalg.api.ndarray.INDArray)}
*/
public ATan2(INDArray x, INDArray y, INDArray z) {
super(new INDArray[]{x, y}, new INDArray[]{ z });
}
public ATan2() {} public ATan2() {}
@Override @Override

View File

@ -46,6 +46,9 @@ public class EqualTo extends BaseDynamicTransformOp {
super(inputs, outputs); super(inputs, outputs);
} }
public EqualTo(INDArray x, INDArray y, INDArray z){
this(new INDArray[]{x, y}, new INDArray[]{z});
}
@Override @Override
public String opName() { public String opName() {

View File

@ -47,7 +47,9 @@ public class GreaterThan extends BaseDynamicTransformOp {
super(inputs, outputs); super(inputs, outputs);
} }
public GreaterThan(INDArray x, INDArray y, INDArray z){
this(new INDArray[]{x, y}, new INDArray[]{z});
}
@Override @Override
public String opName() { public String opName() {

View File

@ -46,6 +46,10 @@ public class GreaterThanOrEqual extends BaseDynamicTransformOp {
super(inputs, outputs); super(inputs, outputs);
} }
public GreaterThanOrEqual(INDArray x, INDArray y, INDArray z){
this(new INDArray[]{x, y}, new INDArray[]{z});
}
@Override @Override
public int opNum() { public int opNum() {
return 11; return 11;

View File

@ -47,6 +47,10 @@ public class LessThan extends BaseDynamicTransformOp {
super(inputs, outputs); super(inputs, outputs);
} }
public LessThan(INDArray x, INDArray y, INDArray z){
this(new INDArray[]{x, y}, new INDArray[]{z});
}
@Override @Override
public String opName() { public String opName() {
return "less"; return "less";

View File

@ -45,6 +45,11 @@ public class LessThanOrEqual extends BaseDynamicTransformOp {
public LessThanOrEqual( INDArray[] inputs, INDArray[] outputs) { public LessThanOrEqual( INDArray[] inputs, INDArray[] outputs) {
super(inputs, outputs); super(inputs, outputs);
} }
public LessThanOrEqual(INDArray x, INDArray y, INDArray z){
this(new INDArray[]{x, y}, new INDArray[]{z});
}
@Override @Override
public String opName() { public String opName() {
return "less_equal"; return "less_equal";

View File

@ -46,6 +46,9 @@ public class NotEqualTo extends BaseDynamicTransformOp {
super(inputs, outputs); super(inputs, outputs);
} }
public NotEqualTo(INDArray x, INDArray y, INDArray z){
this(new INDArray[]{x, y}, new INDArray[]{z});
}
@Override @Override
public String opName() { public String opName() {

View File

@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Arrays; import java.util.Arrays;
@ -38,6 +39,27 @@ public class Reverse extends DynamicCustomOp {
public Reverse() { public Reverse() {
} }
/**
* Inplace reverse. See {@link #Reverse(INDArray, INDArray)}
*/
public Reverse(INDArray x){
this(x, x);
this.inPlace = true;
}
/**
* Reverses whole array for compatibility with OldReverse.
*
* Note that otherwise, passing null or empty dimensions will result in a noop.
*/
public Reverse(INDArray x, INDArray z){
super(new INDArray[]{x}, new INDArray[]{z});
this.dimensions = new int[x.rank()];
for(int i = 0 ; i < this.dimensions.length ; i++)
this.dimensions[i] = i;
addIArgument(dimensions);
}
@Override @Override
public String opName() { public String opName() {
return "reverse"; return "reverse";

View File

@ -52,6 +52,9 @@ public class FModOp extends BaseTransformSameOp {
public FModOp(INDArray x, INDArray z) { public FModOp(INDArray x, INDArray z) {
super(x, z); super(x, z);
} }
public FModOp(INDArray x, INDArray y, INDArray z) {
super(x, y, z);
}
public FModOp(INDArray x) { public FModOp(INDArray x) {
super(x); super(x);

View File

@ -1,89 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformSameOp;
import java.util.ArrayList;
import java.util.List;
/**
* @deprecated Use {@link AddOp}
*/
@Deprecated
public class OldAddOp extends BaseTransformAnyOp {
public OldAddOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
}
public OldAddOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
}
public OldAddOp() {}
public OldAddOp(INDArray x) {
super(x);
}
public OldAddOp(INDArray x, INDArray z) {
super(x, z);
}
public OldAddOp(INDArray x, INDArray y, INDArray z) {
super(x, y, z);
}
@Override
public int opNum() {
return 0;
}
@Override
public String opName() {
return "old_add";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable gradWrtX = f().div(i_v.get(0),rarg());
SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(larg(),rarg()));
List<SDVariable> ret = new ArrayList<>(2);
ret.add(gradWrtX);
ret.add(gradWrtY);
return ret;
}
}

View File

@ -1,86 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
import org.nd4j.linalg.api.ops.BaseTransformFloatOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformSameOp;
import java.util.List;
/**
* atan2 operation
*
* @author raver119@gmail.com
*/
public class OldAtan2Op extends BaseTransformAnyOp {
public OldAtan2Op(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
}
public OldAtan2Op(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
}
public OldAtan2Op(SameDiff sameDiff) {
super(sameDiff);
}
public OldAtan2Op(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
}
public OldAtan2Op() {}
public OldAtan2Op(INDArray x, INDArray y) {
super(x, y, x);
}
public OldAtan2Op(INDArray x, INDArray y, INDArray z) {
super(x, y, z);
}
@Override
public int opNum() {
return 16;
}
@Override
public String opName() {
return "old_atan2";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx opName found for " + opName());
}
@Override
public String tensorflowName() {
return "ATan2";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return null;
}
}

View File

@ -1,88 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformSameOp;
import java.util.ArrayList;
import java.util.List;
/**
* @deprecated Use {@link DivOp}
*/
@Deprecated
public class OldDivOp extends BaseTransformAnyOp {
public OldDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
}
public OldDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
}
public OldDivOp() {}
public OldDivOp(INDArray x, INDArray y, INDArray z) {
super(x, y, z);
}
public OldDivOp(INDArray x) {
super(x);
}
public OldDivOp(INDArray x, INDArray z) {
super(x, z);
}
@Override
public int opNum() {
return 2;
}
@Override
public String opName() {
return "olddiv";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable gradWrtX = f().div(i_v.get(0),rarg());
SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(larg(),rarg()));
List<SDVariable> ret = new ArrayList<>(2);
ret.add(gradWrtX);
ret.add(gradWrtY);
return ret;
}
}

View File

@ -1,88 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformSameOp;
import java.util.List;
/**
* Floating point remainder
*
* @author raver119@gmail.com
*/
public class OldFModOp extends BaseTransformAnyOp {
public OldFModOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
}
public OldFModOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
}
public OldFModOp(SameDiff sameDiff) {
super(sameDiff);
}
public OldFModOp(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
}
public OldFModOp() {}
public OldFModOp(INDArray x) {
super(x);
}
public OldFModOp(INDArray x, INDArray z) {
super(x, z);
}
public OldFModOp(INDArray x, INDArray y, INDArray z) {
super(x, y, z);
}
@Override
public int opNum() {
return 15;
}
@Override
public String opName() {
return "oldfmod";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return null;
}
}

View File

@ -1,89 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformSameOp;
import java.util.ArrayList;
import java.util.List;
/**
* Truncated division operation
*
* @author Adam Gibson
*/
public class OldFloorDivOp extends BaseTransformAnyOp {
public OldFloorDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
}
public OldFloorDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
}
public OldFloorDivOp() {}
public OldFloorDivOp(INDArray x) {
super(x);
}
public OldFloorDivOp(INDArray x, INDArray z) {
super(x, z);
}
public OldFloorDivOp(INDArray x, INDArray y, INDArray z) {
super(x, y, z);
}
@Override
public int opNum() {
return 18;
}
@Override
public String opName() {
return "oldfloordiv";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable gradWrtX = f().div(i_v.get(0),rarg());
SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(larg(),rarg()));
List<SDVariable> ret = new ArrayList<>(2);
ret.add(gradWrtX);
ret.add(gradWrtY);
return ret;
}
}

View File

@ -1,91 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformSameOp;
import java.util.ArrayList;
import java.util.List;
/**
* @deprecated Use {@link MulOp}
*/
@Deprecated
public class OldMulOp extends BaseTransformAnyOp {
public OldMulOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
}
public OldMulOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
}
public OldMulOp() {}
public OldMulOp(INDArray x) {
super(x);
}
public OldMulOp(INDArray x, INDArray z) {
super(x, z);
}
public OldMulOp(INDArray x, INDArray y, INDArray z) {
super(x, y, z);
}
@Override
public int opNum() {
return 3;
}
@Override
public String opName() {
return "oldmul";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable g = sameDiff.setupFunction(i_v.get(0));
SDVariable gradWrtX = f().mul(g,rarg());
SDVariable gradWrtY = f().mul(g,larg());
List<SDVariable> ret = new ArrayList<>(2);
ret.add(gradWrtX);
ret.add(gradWrtY);
return ret;
}
}

View File

@ -1,87 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformSameOp;
import java.util.ArrayList;
import java.util.List;
/**
* @deprecated Use {@link RDivOp}
*/
@Deprecated
public class OldRDivOp extends BaseTransformAnyOp {
public OldRDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
}
public OldRDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
}
public OldRDivOp() {}
public OldRDivOp(INDArray x) {
super(x);
}
public OldRDivOp(INDArray x, INDArray z) {
super(x, z);
}
public OldRDivOp(INDArray x, INDArray y, INDArray z) {
super(x, y, z);
}
@Override
public int opNum() {
return 11;
}
@Override
public String opName() {
return "oldrdiv";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable gradWrtX = f().div(i_v.get(0),larg());
SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(rarg(),larg()));
List<SDVariable> ret = new ArrayList<>(2);
ret.add(gradWrtX);
ret.add(gradWrtY);
return ret;
}
}

View File

@ -1,87 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
import java.util.ArrayList;
import java.util.List;
/**
* @deprecated Use {@link RSubOp}
*/
@Deprecated
public class OldRSubOp extends BaseTransformAnyOp {
public OldRSubOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
}
public OldRSubOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
}
public OldRSubOp() {}
public OldRSubOp(INDArray x) {
super(x);
}
public OldRSubOp(INDArray x, INDArray z) {
super(x, z);
}
public OldRSubOp(INDArray x, INDArray y, INDArray z) {
super(x, y, z);
}
@Override
public int opNum() {
return 5;
}
@Override
public String opName() {
return "old_rsub";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable gradWrtX = f().div(i_v.get(0),rarg());
SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(larg(),rarg()));
List<SDVariable> ret = new ArrayList<>(2);
ret.add(gradWrtX);
ret.add(gradWrtY);
return ret;
}
}

View File

@ -1,89 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformSameOp;
import java.util.ArrayList;
import java.util.List;
/**
* @deprecated Use {@link SubOp}
*/
@Deprecated
public class OldSubOp extends BaseTransformAnyOp {
public OldSubOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
}
public OldSubOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
}
public OldSubOp() {}
public OldSubOp(INDArray x) {
super(x);
}
public OldSubOp(INDArray x, INDArray z) {
super(x, z);
}
public OldSubOp(INDArray x, INDArray y, INDArray z) {
super(x, y, z);
}
@Override
public int opNum() {
return 6;
}
@Override
public String opName() {
return "old_sub";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable gradWrtX = f().div(i_v.get(0),rarg());
SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(larg(),rarg()));
List<SDVariable> ret = new ArrayList<>(2);
ret.add(gradWrtX);
ret.add(gradWrtY);
return ret;
}
}

View File

@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
@ -40,6 +41,10 @@ public class Identity extends BaseDynamicTransformOp {
super(sd, new SDVariable[]{input}, false); super(sd, new SDVariable[]{input}, false);
} }
public Identity(INDArray x, INDArray z){
super(new INDArray[]{x}, new INDArray[]{z});
}
public Identity(){ } public Identity(){ }
@Override @Override

View File

@ -1,77 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.same;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformSameOp;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;
/**
* Identity function
*
* @author Adam Gibson
*/
public class OldIdentity extends BaseTransformSameOp {
public OldIdentity(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
}
public OldIdentity() {
}
public OldIdentity(INDArray x, INDArray z) {
super(x, z);
}
public OldIdentity(INDArray x) {
super(x);
}
@Override
public int opNum() {
return 15;
}
@Override
public String opName() {
return "old_identity";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("This op does not work with onnx.");
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("This op does not work with tensorflow.");
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
return i_v;
}
}

View File

@ -1,74 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.same;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformSameOp;
import java.util.Arrays;
import java.util.List;
/**
* OldReverse op
*/
public class OldReverse extends BaseTransformSameOp {
public OldReverse(SameDiff sameDiff, SDVariable i_v, int... dimensions) {
super(sameDiff, i_v, false);
this.dimensions = dimensions;
}
public OldReverse() {
}
public OldReverse(INDArray x, INDArray z) {
super(x, z);
}
public OldReverse(INDArray x) {
super(x);
}
@Override
public int opNum() {
return 20;
}
@Override
public String opName() {
return "old_reverse";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
SDVariable ret = f().reverse(f1.get(0), dimensions);
return Arrays.asList(ret);
}
}

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.convolution; package org.nd4j.linalg.convolution;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.Pad.Mode;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
@ -129,8 +130,7 @@ public class OldConvolution {
long w = img.size(3); long w = img.size(3);
long outHeight = outSize(h, kh, sy, ph, coverAll); long outHeight = outSize(h, kh, sy, ph, coverAll);
long outWidth = outSize(w, kw, sx, pw, coverAll); long outWidth = outSize(w, kw, sx, pw, coverAll);
INDArray padded = Nd4j.pad(img, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}}, INDArray padded = Nd4j.pad(img, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}}, Mode.CONSTANT, pval);
Nd4j.PadMode.CONSTANT);
INDArray ret = Nd4j.create(n, c, kh, kw, outHeight, outWidth); INDArray ret = Nd4j.create(n, c, kh, kw, outHeight, outWidth);
for (int i = 0; i < kh; i++) { for (int i = 0; i < kh; i++) {
//offset for the row based on the stride and output height //offset for the row based on the stride and output height

View File

@ -35,7 +35,8 @@ public class OmpNumThreadsAction implements EnvironmentalAction {
val skipper = System.getenv(ND4JEnvironmentVars.ND4J_SKIP_BLAS_THREADS); val skipper = System.getenv(ND4JEnvironmentVars.ND4J_SKIP_BLAS_THREADS);
if (skipper == null) { if (skipper == null) {
// we infer num threads only if skipper undefined // we infer num threads only if skipper undefined
Nd4j.setNumThreads(v); // Nd4j.setNumThreads(v);
// method does not do anything anymore and was removed
} }
} }
} }

View File

@ -20,6 +20,14 @@ import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.*; import org.nd4j.linalg.api.ops.impl.broadcast.*;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.*; import org.nd4j.linalg.api.ops.impl.broadcast.bool.*;
import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Min;
import org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.*; import org.nd4j.linalg.api.ops.impl.transforms.comparison.*;
import org.nd4j.linalg.api.ops.impl.transforms.same.AMax; import org.nd4j.linalg.api.ops.impl.transforms.same.AMax;
@ -42,7 +50,7 @@ public class Broadcast {
public static INDArray add(INDArray x, INDArray y, INDArray z, int... dimensions) { public static INDArray add(INDArray x, INDArray y, INDArray z, int... dimensions) {
if(dimensions == null || dimensions.length == 0) { if(dimensions == null || dimensions.length == 0) {
validateShapesNoDimCase(x,y,z); validateShapesNoDimCase(x,y,z);
return Nd4j.getExecutioner().exec(new OldAddOp(x,y,z)); return Nd4j.getExecutioner().exec(new AddOp(x,y,z))[0];
} }
return Nd4j.getExecutioner().exec(new BroadcastAddOp(x,y,z,dimensions)); return Nd4j.getExecutioner().exec(new BroadcastAddOp(x,y,z,dimensions));
@ -66,7 +74,7 @@ public class Broadcast {
public static INDArray div(INDArray x, INDArray y, INDArray z, int... dimensions) { public static INDArray div(INDArray x, INDArray y, INDArray z, int... dimensions) {
if(dimensions == null || dimensions.length == 0) { if(dimensions == null || dimensions.length == 0) {
validateShapesNoDimCase(x,y,z); validateShapesNoDimCase(x,y,z);
return Nd4j.getExecutioner().exec(new OldDivOp(x,y,z)); return Nd4j.getExecutioner().exec(new DivOp(x,y,z))[0];
} }
return Nd4j.getExecutioner().exec(new BroadcastDivOp(x,y,z,dimensions)); return Nd4j.getExecutioner().exec(new BroadcastDivOp(x,y,z,dimensions));
@ -78,7 +86,7 @@ public class Broadcast {
public static INDArray eq(INDArray x, INDArray y, INDArray z, int... dimensions) { public static INDArray eq(INDArray x, INDArray y, INDArray z, int... dimensions) {
if(dimensions == null || dimensions.length == 0) { if(dimensions == null || dimensions.length == 0) {
validateShapesNoDimCase(x,y,z); validateShapesNoDimCase(x,y,z);
return Nd4j.getExecutioner().exec(new OldEqualTo(x,y,z)); return Nd4j.getExecutioner().exec(new EqualTo(x,y,z))[0];
} }
return Nd4j.getExecutioner().exec(new BroadcastEqualTo(x,y,z,dimensions)); return Nd4j.getExecutioner().exec(new BroadcastEqualTo(x,y,z,dimensions));
} }
@ -89,7 +97,7 @@ public class Broadcast {
public static INDArray gt(INDArray x, INDArray y, INDArray z, int... dimensions) { public static INDArray gt(INDArray x, INDArray y, INDArray z, int... dimensions) {
if(dimensions == null || dimensions.length == 0) { if(dimensions == null || dimensions.length == 0) {
validateShapesNoDimCase(x,y,z); validateShapesNoDimCase(x,y,z);
return Nd4j.getExecutioner().exec(new OldGreaterThan(x,y,z)); return Nd4j.getExecutioner().exec(new GreaterThan(x,y,z))[0];
} }
return Nd4j.getExecutioner().exec(new BroadcastGreaterThan(x,y,z,dimensions)); return Nd4j.getExecutioner().exec(new BroadcastGreaterThan(x,y,z,dimensions));
@ -101,7 +109,7 @@ public class Broadcast {
public static INDArray gte(INDArray x, INDArray y, INDArray z, int... dimensions) { public static INDArray gte(INDArray x, INDArray y, INDArray z, int... dimensions) {
if(dimensions == null || dimensions.length == 0) { if(dimensions == null || dimensions.length == 0) {
validateShapesNoDimCase(x,y,z); validateShapesNoDimCase(x,y,z);
return Nd4j.getExecutioner().exec(new OldGreaterThanOrEqual(x,y,z)); return Nd4j.getExecutioner().exec(new GreaterThanOrEqual(x,y,z))[0];
} }
return Nd4j.getExecutioner().exec(new BroadcastGreaterThanOrEqual(x,y,z,dimensions)); return Nd4j.getExecutioner().exec(new BroadcastGreaterThanOrEqual(x,y,z,dimensions));
@ -113,7 +121,7 @@ public class Broadcast {
public static INDArray lt(INDArray x, INDArray y, INDArray z, int... dimensions) { public static INDArray lt(INDArray x, INDArray y, INDArray z, int... dimensions) {
if(dimensions == null || dimensions.length == 0) { if(dimensions == null || dimensions.length == 0) {
validateShapesNoDimCase(x,y,z); validateShapesNoDimCase(x,y,z);
return Nd4j.getExecutioner().exec(new OldLessThan(x,y,z)); return Nd4j.getExecutioner().exec(new LessThan(x,y,z))[0];
} }
return Nd4j.getExecutioner().exec(new BroadcastLessThan(x,y,z,dimensions)); return Nd4j.getExecutioner().exec(new BroadcastLessThan(x,y,z,dimensions));
@ -125,7 +133,7 @@ public class Broadcast {
public static INDArray lte(INDArray x, INDArray y, INDArray z, int... dimensions) { public static INDArray lte(INDArray x, INDArray y, INDArray z, int... dimensions) {
if(dimensions == null || dimensions.length == 0) { if(dimensions == null || dimensions.length == 0) {
validateShapesNoDimCase(x,y,z); validateShapesNoDimCase(x,y,z);
return Nd4j.getExecutioner().exec(new OldLessThanOrEqual(x,y,z)); return Nd4j.getExecutioner().exec(new LessThanOrEqual(x,y,z))[0];
} }
return Nd4j.getExecutioner().exec(new BroadcastLessThanOrEqual(x,y,z,dimensions)); return Nd4j.getExecutioner().exec(new BroadcastLessThanOrEqual(x,y,z,dimensions));
@ -137,7 +145,7 @@ public class Broadcast {
public static INDArray mul(INDArray x, INDArray y, INDArray z, int... dimensions) { public static INDArray mul(INDArray x, INDArray y, INDArray z, int... dimensions) {
if(dimensions == null || dimensions.length == 0) { if(dimensions == null || dimensions.length == 0) {
validateShapesNoDimCase(x,y,z); validateShapesNoDimCase(x,y,z);
return Nd4j.getExecutioner().exec(new OldMulOp(x,y,z)); return Nd4j.getExecutioner().exec(new MulOp(x,y,z))[0];
} }
return Nd4j.getExecutioner().exec(new BroadcastMulOp(x,y,z,dimensions)); return Nd4j.getExecutioner().exec(new BroadcastMulOp(x,y,z,dimensions));
@ -149,7 +157,7 @@ public class Broadcast {
public static INDArray neq(INDArray x, INDArray y, INDArray z, int... dimensions) { public static INDArray neq(INDArray x, INDArray y, INDArray z, int... dimensions) {
if(dimensions == null || dimensions.length == 0) { if(dimensions == null || dimensions.length == 0) {
validateShapesNoDimCase(x,y,z); validateShapesNoDimCase(x,y,z);
return Nd4j.getExecutioner().exec(new OldNotEqualTo(x,y,z)); return Nd4j.getExecutioner().exec(new NotEqualTo(x,y,z))[0];
} }
return Nd4j.getExecutioner().exec(new BroadcastNotEqual(x,y,z,dimensions)); return Nd4j.getExecutioner().exec(new BroadcastNotEqual(x,y,z,dimensions));
@ -161,7 +169,7 @@ public class Broadcast {
public static INDArray rdiv(INDArray x, INDArray y, INDArray z, int... dimensions) { public static INDArray rdiv(INDArray x, INDArray y, INDArray z, int... dimensions) {
if(dimensions == null || dimensions.length == 0) { if(dimensions == null || dimensions.length == 0) {
validateShapesNoDimCase(x,y,z); validateShapesNoDimCase(x,y,z);
return Nd4j.getExecutioner().exec(new OldRDivOp(x,y,z)); return Nd4j.getExecutioner().exec(new RDivOp(x,y,z))[0];
} }
return Nd4j.getExecutioner().exec(new BroadcastRDivOp(x,y,z,dimensions)); return Nd4j.getExecutioner().exec(new BroadcastRDivOp(x,y,z,dimensions));
@ -173,7 +181,7 @@ public class Broadcast {
public static INDArray rsub(INDArray x, INDArray y, INDArray z, int... dimensions) { public static INDArray rsub(INDArray x, INDArray y, INDArray z, int... dimensions) {
if(dimensions == null || dimensions.length == 0) { if(dimensions == null || dimensions.length == 0) {
validateShapesNoDimCase(x,y,z); validateShapesNoDimCase(x,y,z);
return Nd4j.getExecutioner().exec(new OldSubOp(x,y,z)); return Nd4j.getExecutioner().exec(new SubOp(x,y,z))[0];
} }
return Nd4j.getExecutioner().exec(new BroadcastRSubOp(x,y,z,dimensions)); return Nd4j.getExecutioner().exec(new BroadcastRSubOp(x,y,z,dimensions));
@ -185,7 +193,7 @@ public class Broadcast {
public static INDArray sub(INDArray x, INDArray y, INDArray z, int... dimensions) { public static INDArray sub(INDArray x, INDArray y, INDArray z, int... dimensions) {
if(dimensions == null || dimensions.length == 0) { if(dimensions == null || dimensions.length == 0) {
validateShapesNoDimCase(x,y,z); validateShapesNoDimCase(x,y,z);
return Nd4j.getExecutioner().exec(new OldSubOp(x,y,z)); return Nd4j.getExecutioner().exec(new SubOp(x,y,z))[0];
} }
return Nd4j.getExecutioner().exec(new BroadcastSubOp(x,y,z,dimensions)); return Nd4j.getExecutioner().exec(new BroadcastSubOp(x,y,z,dimensions));
@ -197,7 +205,7 @@ public class Broadcast {
public static INDArray max(INDArray x, INDArray y, INDArray z, int... dimensions) { public static INDArray max(INDArray x, INDArray y, INDArray z, int... dimensions) {
if(dimensions == null || dimensions.length == 0) { if(dimensions == null || dimensions.length == 0) {
validateShapesNoDimCase(x,y,z); validateShapesNoDimCase(x,y,z);
return Nd4j.getExecutioner().exec(new OldMax(x,y,z)); return Nd4j.getExecutioner().exec(new Max(x,y,z))[0];
} }
@ -210,7 +218,7 @@ public class Broadcast {
public static INDArray min(INDArray x, INDArray y, INDArray z, int... dimensions) { public static INDArray min(INDArray x, INDArray y, INDArray z, int... dimensions) {
if(dimensions == null || dimensions.length == 0) { if(dimensions == null || dimensions.length == 0) {
validateShapesNoDimCase(x,y,z); validateShapesNoDimCase(x,y,z);
return Nd4j.getExecutioner().exec(new OldMin(x,y,z)); return Nd4j.getExecutioner().exec(new Min(x,y,z))[0];
} }

View File

@ -57,8 +57,10 @@ import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
import org.nd4j.linalg.api.ops.impl.shape.Diag; import org.nd4j.linalg.api.ops.impl.shape.Diag;
import org.nd4j.linalg.api.ops.impl.shape.DiagPart; import org.nd4j.linalg.api.ops.impl.shape.DiagPart;
import org.nd4j.linalg.api.ops.impl.shape.Stack; import org.nd4j.linalg.api.ops.impl.shape.Stack;
import org.nd4j.linalg.api.ops.impl.transforms.Pad;
import org.nd4j.linalg.api.ops.impl.transforms.Pad.Mode;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
import org.nd4j.linalg.api.ops.impl.shape.Tile; import org.nd4j.linalg.api.ops.impl.shape.Tile;
import org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse;
import org.nd4j.linalg.api.ops.random.custom.RandomExponential; import org.nd4j.linalg.api.ops.random.custom.RandomExponential;
import org.nd4j.linalg.api.ops.random.impl.*; import org.nd4j.linalg.api.ops.random.impl.*;
import org.nd4j.linalg.api.rng.DefaultRandom; import org.nd4j.linalg.api.rng.DefaultRandom;
@ -182,93 +184,69 @@ public class Nd4j {
nd4j.initContext(); nd4j.initContext();
} }
public enum PadMode {
CONSTANT, EDGE, LINEAR_RAMP, MAXIMUM, MEAN, MEDIAN, MINIMUM, REFLECT, SYMMETRIC, WRAP
}
/** /**
* See {@link #pad(INDArray, int[][], List, PadMode)} with zero padding. (zeros for constantValues). * See {@link #pad(INDArray, INDArray)}. Uses 0 padding.
*/ */
public static INDArray pad(INDArray toPad, int[][] padWidth, PadMode padMode) { public static INDArray pad(@NonNull INDArray toPad, @NonNull int[][] padWidth){
return pad(toPad, padWidth, ArrayUtil.zerosMatrix(toPad.shape()), padMode); return pad(toPad, Nd4j.createFromArray(padWidth));
} }
/** /**
* Pad the given ndarray to the size along each dimension * See {@link #pad(INDArray, INDArray)}. Uses 0 padding, and uses padWidth for all dimensions.
*/
public static INDArray pad(@NonNull INDArray toPad, @NonNull int... padWidth){
return pad(toPad, padWidth, Mode.CONSTANT, 0);
}
/**
* See {@link #pad(INDArray, INDArray, Pad.Mode, double)} with zero padding (zeros for padValue).
*/
public static INDArray pad(INDArray toPad, INDArray padding) {
return pad(toPad, padding, Mode.CONSTANT, 0);
}
/**
* See {@link #pad(INDArray, INDArray, Mode, double)}.
*/
public static INDArray pad(@NonNull INDArray toPad, @NonNull int[][] padWidth, @NonNull Pad.Mode padMode, double padValue){
return pad(toPad, Nd4j.createFromArray(padWidth), padMode, padValue);
}
/**
* See {@link #pad(INDArray, INDArray, Mode, double)}, uses padWidth for all dimensions.
*/
public static INDArray pad(@NonNull INDArray toPad, @NonNull int[] padWidth, @NonNull Pad.Mode padMode, double padValue){
int[][] pads = new int[toPad.rank()][padWidth.length];
for(int i = 0 ; i < pads.length ; i++){
pads[i] = padWidth;
}
return pad(toPad, pads, padMode, padValue);
}
/**
* Pad the given ndarray to the size along each dimension.
*
* @param toPad the ndarray to pad * @param toPad the ndarray to pad
* @param padWidth the width to pad along each dimension * @param padWidth the width to pad along each dimension
* @param constantValues the values to append for each dimension
* @param padMode the mode to pad in * @param padMode the mode to pad in
* @param padValue the value used during padding. Only used when padMode is {@link Pad.Mode#CONSTANT}.
* @return the padded ndarray * @return the padded ndarray
* based on the specified mode * based on the specified mode
*/ */
public static INDArray pad(INDArray toPad, int[][] padWidth, List<double[]> constantValues, PadMode padMode) { public static INDArray pad(@NonNull INDArray toPad, @NonNull INDArray padWidth, @NonNull Pad.Mode padMode, double padValue) {
if (padMode == PadMode.CONSTANT) {
if (padWidth.length < toPad.rank())
throw new IllegalArgumentException("Please specify a pad width for each dimension");
List<int[]> sizes = new ArrayList<>(); Preconditions.checkArgument(toPad.rank() == padWidth.size(0),
for (int i = 0; i < toPad.rank(); i++) { "Must provide padding values for each dimension. Expected %s pairs for a rank %s array, got %s",
sizes.add(padWidth[i]); toPad.rank(), toPad.rank(), padWidth.size(0));
}
return padImpl(toPad, sizes, constantValues); long[] newShape = new long[toPad.rank()];
for(int i = 0 ; i < newShape.length ; i++){
newShape[i] = toPad.size(i) + padWidth.getRow(i).sumNumber().intValue();
} }
throw new UnsupportedOperationException(); INDArray out = Nd4j.createUninitialized(toPad.dataType(), newShape);
} Pad op = new Pad(toPad, padWidth, out, padMode, padValue);
/** return Nd4j.getExecutioner().exec(op)[0];
* See {@link #pad(INDArray, int[][], List, PadMode)} with a 1D int[] for padWidth.
*/
public static INDArray pad(INDArray toPad, int[] padWidth, List<double[]> constantValues, PadMode padMode) {
if (padMode == PadMode.CONSTANT) {
if (padWidth.length < toPad.rank())
throw new IllegalArgumentException("Please specify a pad width for each dimension");
toPad = Nd4j.stripOnes(toPad);
List<int[]> sizes = new ArrayList<>();
for (int i = 0; i < toPad.rank(); i++) {
sizes.add(padWidth);
}
return padImpl(toPad, sizes, constantValues);
}
throw new UnsupportedOperationException();
}
// common code for pad(INDArray, int[], List<double[]>, PadMode) and
// pad(INDArray, int[][], List<double[]>, PadMode)
private static INDArray padImpl(INDArray toPad, List<int[]> sizes, List<double[]> constantValues){
INDArray ret = toPad;
for (int i = 0; i < toPad.rank(); i++) {
int[] pad = sizes.get(i);
double[] constant = constantValues.get(i);
int padBefore = pad[0];
int padAfter = pad[1];
if (constant.length < 2) {
double val = constant[0];
constant = new double[2];
constant[0] = val;
constant[1] = val;
}
double beforeVal = constant[0];
double afterVal = constant[1];
ret = Nd4j.prepend(ret, padBefore, beforeVal, i);
ret = Nd4j.append(ret, padAfter, afterVal, i);
}
return ret;
}
/**
* See {@link #pad(INDArray, int[][], List, PadMode)} with a 1D int[] for padWidth and zero padding.
*/
public static INDArray pad(INDArray toPad, int[] padWidth, PadMode padMode) {
return pad(toPad, padWidth, ArrayUtil.zerosMatrix(padWidth), padMode);
} }
/** /**
@ -2639,7 +2617,7 @@ public class Nd4j {
* @return the reversed matrix * @return the reversed matrix
*/ */
public static INDArray reverse(INDArray reverse) { public static INDArray reverse(INDArray reverse) {
return Nd4j.getExecutioner().exec(new OldReverse(reverse)); return Nd4j.getExecutioner().exec(new Reverse(reverse))[0];
} }
/** /**
@ -5962,27 +5940,6 @@ public class Nd4j {
} }
/**
* This method returns maximal allowed number of threads for Nd4j.
* If value wasn't set in advance, max(1, availableProcessor) will be returned
* @return maximal allowed number of threads
*/
public static int numThreads() {
val v = numThreads.get();
if (v <= 0)
return Math.max(1, Runtime.getRuntime().availableProcessors() / 2);
else
return v;
}
/**
* This method sets maximal allowed number of threads for Nd4j
* @param numthreads maximal allowed number of threads
*/
public static void setNumThreads(int numthreads) {
numThreads.set(numthreads);
}
public static DataType defaultFloatingPointType() { public static DataType defaultFloatingPointType() {
return defaultFloatingPointDataType.get(); return defaultFloatingPointDataType.get();
} }

View File

@ -20,7 +20,7 @@ import lombok.Data;
import lombok.NonNull; import lombok.NonNull;
import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
@ -104,7 +104,7 @@ public class AdaMaxUpdater implements GradientUpdater<AdaMax> {
//u = max(B_2 * u, |grad|) //u = max(B_2 * u, |grad|)
u.muli(config.getBeta2()); u.muli(config.getBeta2());
Transforms.abs(gradient, false); //In-place should be OK here, original gradient values aren't used again later Transforms.abs(gradient, false); //In-place should be OK here, original gradient values aren't used again later
Nd4j.getExecutioner().exec(new OldMax(u, gradient, u)); Nd4j.getExecutioner().exec(new Max(u, gradient, u));
double beta1t = FastMath.pow(config.getBeta1(), iteration + 1); double beta1t = FastMath.pow(config.getBeta1(), iteration + 1);

View File

@ -19,7 +19,7 @@ package org.nd4j.linalg.learning;
import lombok.Data; import lombok.Data;
import lombok.NonNull; import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.learning.config.Nesterovs;
@ -105,6 +105,6 @@ public class NesterovsUpdater implements GradientUpdater<Nesterovs> {
INDArray ret = vPrev.muli(momentum).addi(v.mul(-momentum - 1)); INDArray ret = vPrev.muli(momentum).addi(v.mul(-momentum - 1));
gradient.assign(ret); gradient.assign(ret);
*/ */
Nd4j.getExecutioner().exec(new OldAddOp(vPrev.muli(momentum), v.mul(-momentum - 1), gradient)); Nd4j.getExecutioner().exec(new AddOp(vPrev.muli(momentum), v.mul(-momentum - 1), gradient));
} }
} }

View File

@ -30,6 +30,10 @@ import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarNot;
import org.nd4j.linalg.api.ops.impl.shape.Cross; import org.nd4j.linalg.api.ops.impl.shape.Cross;
import org.nd4j.linalg.api.ops.impl.transforms.bool.BooleanNot; import org.nd4j.linalg.api.ops.impl.transforms.bool.BooleanNot;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.floating.*; import org.nd4j.linalg.api.ops.impl.transforms.floating.*;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.*; import org.nd4j.linalg.api.ops.impl.transforms.comparison.*;
@ -37,7 +41,6 @@ import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAtan2Op;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.PowPairwise; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.PowPairwise;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
@ -104,7 +107,7 @@ public class Transforms {
public static INDArray reverse(INDArray x, boolean dup) { public static INDArray reverse(INDArray x, boolean dup) {
return Nd4j.getExecutioner().exec(new OldReverse(x, dup ? x.ulike() : x)); return Nd4j.getExecutioner().exec(new Reverse(x, dup ? x.ulike() : x))[0];
} }
/** /**
@ -140,14 +143,15 @@ public class Transforms {
/** /**
* Atan2 operation, new INDArray instance will be returned * Atan2 operation, new INDArray instance will be returned
* Note the order of x and y parameters is opposite to that of java.lang.Math.atan2 * Note the order of x and y parameters is opposite to that of {@link java.lang.Math#atan2(double, double)}
* *
* @param x the abscissa coordinate * @param x the abscissa coordinate
* @param y the ordinate coordinate * @param y the ordinate coordinate
* @return the theta from point (r, theta) when converting (x,y) from to cartesian to polar coordinates * @return the theta from point (r, theta) when converting (x,y) from to cartesian to polar coordinates
*/ */
public static INDArray atan2(@NonNull INDArray x, @NonNull INDArray y) { public static INDArray atan2(@NonNull INDArray x, @NonNull INDArray y) {
return Nd4j.getExecutioner().exec(new OldAtan2Op(x, y, x.ulike())); // Switched on purpose, to match OldATan2 (which the javadoc was written for)
return Nd4j.getExecutioner().exec(new ATan2(y, x, x.ulike()))[0];
} }
/** /**
@ -789,7 +793,7 @@ public class Transforms {
* @return * @return
*/ */
public static INDArray lessThanOrEqual(INDArray first, INDArray ndArray, boolean dup) { public static INDArray lessThanOrEqual(INDArray first, INDArray ndArray, boolean dup) {
return exec(new OldLessThanOrEqual(first, ndArray, Nd4j.createUninitialized(DataType.BOOL, first.shape(), first.ordering()))); return Nd4j.getExecutioner().exec(new LessThanOrEqual(first, ndArray, Nd4j.createUninitialized(DataType.BOOL, first.shape(), first.ordering())))[0];
} }
@ -801,7 +805,7 @@ public class Transforms {
* @return * @return
*/ */
public static INDArray greaterThanOrEqual(INDArray first, INDArray ndArray, boolean dup) { public static INDArray greaterThanOrEqual(INDArray first, INDArray ndArray, boolean dup) {
return exec(new OldGreaterThanOrEqual(first, ndArray, Nd4j.createUninitialized(DataType.BOOL, first.shape(), first.ordering()))); return Nd4j.getExecutioner().exec(new GreaterThanOrEqual(first, ndArray, Nd4j.createUninitialized(DataType.BOOL, first.shape(), first.ordering())))[0];
} }
@ -986,7 +990,7 @@ public class Transforms {
* @return * @return
*/ */
public static INDArray identity(INDArray ndArray, boolean dup) { public static INDArray identity(INDArray ndArray, boolean dup) {
return exec(dup ? new OldIdentity(ndArray, ndArray.ulike()) : new OldIdentity(ndArray)); return Nd4j.getExecutioner().exec(dup ? new Identity(ndArray, ndArray.ulike()) : new Identity(ndArray, ndArray))[0];
} }
public static INDArray isMax(INDArray input, DataType dataType) { public static INDArray isMax(INDArray input, DataType dataType) {

View File

@ -16,6 +16,13 @@
package org.nd4j.autodiff.opvalidation; package org.nd4j.autodiff.opvalidation;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.junit.Ignore; import org.junit.Ignore;
@ -32,21 +39,20 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm; import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
@Slf4j @Slf4j
public class LayerOpValidation extends BaseOpValidation { public class LayerOpValidation extends BaseOpValidation {
public LayerOpValidation(Nd4jBackend backend) { public LayerOpValidation(Nd4jBackend backend) {
@ -311,7 +317,7 @@ public class LayerOpValidation extends BaseOpValidation {
SDVariable loss = sd.mean("loss", out); SDVariable loss = sd.mean("loss", out);
log.info("Starting test: " + msg); log.info("Starting test: " + msg);
TestCase tc = new TestCase(sd); TestCase tc = new TestCase(sd).gradientCheck(true);
String error = OpValidation.validate(tc); String error = OpValidation.validate(tc);
if (error != null) { if (error != null) {
failed.add(msg); failed.add(msg);
@ -344,7 +350,7 @@ public class LayerOpValidation extends BaseOpValidation {
String msg = Arrays.toString(inSizeNCHW); String msg = Arrays.toString(inSizeNCHW);
TestCase tc = new TestCase(sd).testName(msg); TestCase tc = new TestCase(sd).gradientCheck(true).testName(msg);
String error = OpValidation.validate(tc); String error = OpValidation.validate(tc);
if (error != null) { if (error != null) {
failed.add(msg); failed.add(msg);
@ -552,7 +558,7 @@ public class LayerOpValidation extends BaseOpValidation {
SDVariable loss = sd.standardDeviation("loss", out, true); SDVariable loss = sd.standardDeviation("loss", out, true);
log.info("Starting test: " + msg); log.info("Starting test: " + msg);
TestCase tc = new TestCase(sd); TestCase tc = new TestCase(sd).gradientCheck(true);
tc.testName(msg); tc.testName(msg);
String error = OpValidation.validate(tc); String error = OpValidation.validate(tc);
if (error != null) { if (error != null) {
@ -660,7 +666,7 @@ public class LayerOpValidation extends BaseOpValidation {
// System.out.println(sd.getFunction("grad").summary()); // System.out.println(sd.getFunction("grad").summary());
//Gradient check: //Gradient check:
TestCase tc = new TestCase(sd); TestCase tc = new TestCase(sd).gradientCheck(true);
String err = OpValidation.validate(tc); String err = OpValidation.validate(tc);
assertNull(err); assertNull(err);
} }
@ -705,7 +711,7 @@ public class LayerOpValidation extends BaseOpValidation {
SDVariable loss = out.std(true); SDVariable loss = out.std(true);
//Gradient check: //Gradient check:
TestCase tc = new TestCase(sd); TestCase tc = new TestCase(sd).gradientCheck(true);
String err = OpValidation.validate(tc); String err = OpValidation.validate(tc);
assertNull(err); assertNull(err);
} }
@ -798,7 +804,7 @@ public class LayerOpValidation extends BaseOpValidation {
exp.putScalar(next, max); exp.putScalar(next, max);
} }
assertNull(OpValidation.validate(new TestCase(sd) assertNull(OpValidation.validate(new TestCase(sd).gradientCheck(true)
.expected(outPool, exp))); .expected(outPool, exp)));
} }
@ -856,7 +862,7 @@ public class LayerOpValidation extends BaseOpValidation {
} }
assertNull(OpValidation.validate(new TestCase(sd) assertNull(OpValidation.validate(new TestCase(sd)
.expected(outPool, exp))); .expected(outPool, exp).gradientCheck(true)));
} }
@ -887,16 +893,12 @@ public class LayerOpValidation extends BaseOpValidation {
.build(); .build();
SDVariable out = sd.cnn().avgPooling3d(in, pooling3DConfig); SDVariable out = sd.cnn().avgPooling3d(in, pooling3DConfig);
out = sd.nn().tanh("out", out); out = sd.nn().tanh("loss", out).shape().rename("out");
INDArray outArr = sd.execAndEndResult();
val outShape = outArr.shape();
// oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1;
assertArrayEquals(new long[]{mb, nIn, 4, 4, 4}, outShape); INDArray outArr = Nd4j.createFromArray(mb, nIn, 4, 4, 4L);
SDVariable loss = out.std(true); TestCase tc = new TestCase(sd).expectedOutput("out", outArr).gradientCheck(true);
//Gradient check:
TestCase tc = new TestCase(sd);
String err = OpValidation.validate(tc); String err = OpValidation.validate(tc);
assertNull(err); assertNull(err);
} }
@ -927,12 +929,16 @@ public class LayerOpValidation extends BaseOpValidation {
.build(); .build();
SDVariable out = sd.cnn().maxPooling3d(in, pooling3DConfig); SDVariable out = sd.cnn().maxPooling3d(in, pooling3DConfig);
out = sd.nn().tanh("out", out); out = sd.nn().tanh("loss", out).shape().rename("out");
sd.setLossVariables("loss");
INDArray outArr = sd.execAndEndResult();
val outShape = outArr.shape();
// oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1;
assertArrayEquals(new long[]{mb, nIn, 27, 27, 27}, outShape); INDArray outArr = Nd4j.createFromArray(mb, nIn, 27, 27, 27L);
TestCase tc = new TestCase(sd).expectedOutput("out", outArr).gradientCheck(true);
String err = OpValidation.validate(tc);
assertNull(err);
} }
@Test @Test
@ -958,13 +964,58 @@ public class LayerOpValidation extends BaseOpValidation {
.build(); .build();
SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig); SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig);
out = sd.nn().tanh("out", out); out = sd.nn().tanh("loss", out).shape().rename("out");
sd.setLossVariables("loss");
INDArray outArr = sd.execAndEndResult();
INDArray iOut = out.getArr();
//Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27 //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27
val outShape = outArr.shape(); INDArray outArr = Nd4j.createFromArray(mb, nOut, 27L);
assertArrayEquals(new long[]{mb, nOut, 27}, outShape); TestCase tc = new TestCase(sd).expectedOutput("out", outArr).gradientCheck(false);
String err = OpValidation
.validate(tc);
assertNull(err);
}
@Test
public void testConv1dForward(){
int nIn = 2;
int nOut = 1;
int kernel = 3;
int batchSize = 10;
int sequenceSize = 5;
SameDiff sd = SameDiff.create();
INDArray inArr = Nd4j.linspace(0, nIn * batchSize * sequenceSize, nIn * batchSize * sequenceSize)
.reshape(batchSize, nIn, sequenceSize);
INDArray wArr = Nd4j.linspace(0, kernel * nIn * nOut, kernel * nIn * nOut)
.reshape(kernel, nIn, nOut);
SDVariable in = sd.var("in", inArr);
SDVariable w = sd.var("w", wArr);
SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).build());
INDArray expected = Nd4j.createFromArray(
new double[][][]{
{{82.42424f, 100.60606f, 118.78788f}},
{{264.2424f, 282.4242f, 300.6061f}},
{{446.0606f, 464.2424f, 482.424f}},
{{627.8788f, 646.0606f, 664.2424f}},
{{809.6970f, 827.8788f, 846.0606f}},
{{991.5152f, 1009.69696f, 1027.8788f}},
{{1173.3333f, 1191.5152f, 1209.6970f}},
{{1355.1515f, 1373.3333f, 1391.5153f}},
{{1536.9697f, 1555.1515f, 1573.3333f}},
{{1718.7878f, 1736.9697f, 1755.1515f}}
}
);
TestCase tc = new TestCase(sd).gradientCheck(false).expectedOutput(res.getVarName(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
} }
@ -1000,17 +1051,61 @@ public class LayerOpValidation extends BaseOpValidation {
.build(); .build();
SDVariable out = sd.cnn().conv3d(in, w, b, conv3DConfig); SDVariable out = sd.cnn().conv3d(in, w, b, conv3DConfig);
out = sd.nn().tanh("out", out); out = sd.nn().tanh("loss", out).shape().rename("out");
sd.setLossVariables("loss");
INDArray outArr = sd.execAndEndResult();
//Expected output size, NOT same mode: out = (in - k)/d + 1 = (28-2+0)/1+1 = 27 //Expected output size, NOT same mode: out = (in - k)/d + 1 = (28-2+0)/1+1 = 27
//Expected output size, WITH same mode: out = in/stride //Expected output size, WITH same mode: out = in/stride
val outShape = outArr.shape(); INDArray outArr = Nd4j.createFromArray(mb, nOut, 5, 5, 5L);
assertArrayEquals(new long[]{mb, nOut, 5, 5, 5}, outShape);
SDVariable loss = out.std(true); TestCase tc = new TestCase(sd).expectedOutput("out", outArr).gradientCheck(true);
//Gradient check: String err = OpValidation
TestCase tc = new TestCase(sd); .validate(tc);
assertNull(err);
}
@Test
public void testDeConv3dBasic() {
int nIn = 4;
int nOut = 3;
int kH = 2;
int kW = 2;
int kD = 2;
int mb = 3;
int imgH = 5;
int imgW = 5;
int imgT = 5;
SameDiff sd = SameDiff.create();
INDArray inArr = Nd4j.rand(new long[]{mb, nIn, 5, 5, 5});
INDArray wArr = Nd4j.rand(kD, kH, kW, nOut, nIn);
SDVariable in = sd.var("in", inArr);
SDVariable w = sd.var("W", wArr);
DeConv3DConfig conv3DConfig = DeConv3DConfig.builder()
.kH(kH).kW(kW).kD(kD)
.sD(1).sH(1).sW(1)
.dH(1).dW(1).dD(1)
.isSameMode(true)
.dataFormat(DeConv3DConfig.NCDHW)
.build();
SDVariable out = sd.cnn().deconv3d(in, w, conv3DConfig);
out = sd.nn().tanh("loss", out).shape().rename("out");
sd.setLossVariables("loss");
//Expected conv3d size, NOT same mode: out = (in - k)/d + 1 = (28-2+0)/1+1 = 27
//Expected conv3d size, WITH same mode: out = in/stride
// reversed this for deconv3d
INDArray outArr = Nd4j.createFromArray(new long[]{mb, nOut, imgT, imgH, imgW});
TestCase tc = new TestCase(sd)
.expectedOutput("out", outArr)
.gradientCheck(true);
String err = OpValidation.validate(tc); String err = OpValidation.validate(tc);
assertNull(err); assertNull(err);
} }
@ -1181,23 +1276,23 @@ public class LayerOpValidation extends BaseOpValidation {
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
for (boolean ncdhw : new boolean[]{true, false}) { for (boolean ncdhw : new boolean[]{true, false}) {
int nIn = inSizeNCDHW[1]; int nIn = inSizeNCDHW[1];
int[] shape = (ncdhw ? inSizeNCDHW : ncdhwToNdhwc(inSizeNCDHW)); int[] shape = (ncdhw ? inSizeNCDHW : ncdhwToNdhwc(inSizeNCDHW));
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", shape); SDVariable in = sd.var("in", shape);
SDVariable out; SDVariable out;
String msg = "0 - conv3d+bias+same, ncdhw=" + ncdhw + " - input " + Arrays.toString(shape); String msg = "0 - conv3d+bias+same, ncdhw=" + ncdhw + " - input " + Arrays.toString(shape);
SDVariable w0 = sd.var("w0", Nd4j.rand(new int[]{2, 2, 2, nIn, 3}).muli(10)); //[kD, kH, kW, iC, oC] SDVariable w0 = sd.var("w0", Nd4j.rand(new int[]{2, 2, 2, nIn, 3}).muli(10)); //[kD, kH, kW, iC, oC]
SDVariable b0 = sd.var("b0", Nd4j.rand(new long[]{3}).muli(10)); SDVariable b0 = sd.var("b0", Nd4j.rand(new long[]{3}).muli(10));
out = sd.cnn().conv3d(in, w0, b0, Conv3DConfig.builder() out = sd.cnn().conv3d(in, w0, b0, Conv3DConfig.builder()
.dataFormat(ncdhw ? Conv3DConfig.NCDHW : Conv3DConfig.NDHWC) .dataFormat(ncdhw ? Conv3DConfig.NCDHW : Conv3DConfig.NDHWC)
.isSameMode(true) .isSameMode(true)
.kH(2).kW(2).kD(2) .kH(2).kW(2).kD(2)
.sD(1).sH(1).sW(-1).dW(-1) .sD(1).sH(1).sW(-1).dW(-1)
.build()); .build());
} }
} }

View File

@ -38,10 +38,10 @@ import org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication; import org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication;
import org.nd4j.linalg.api.ops.impl.shape.Cross; import org.nd4j.linalg.api.ops.impl.shape.Cross;
import org.nd4j.linalg.api.ops.impl.transforms.Pad; import org.nd4j.linalg.api.ops.impl.transforms.Pad;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual; import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual; import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Min;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize; import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt; import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
@ -1008,7 +1008,7 @@ public class TransformOpValidation extends BaseOpValidation {
} }
DifferentialFunction[] funcs = sd.functions(); DifferentialFunction[] funcs = sd.ops();
String name = opName == null ? funcs[0].opName() : opName; String name = opName == null ? funcs[0].opName() : opName;
@ -1141,11 +1141,11 @@ public class TransformOpValidation extends BaseOpValidation {
break; break;
case 14: case 14:
t = sd.max(in1, in2); t = sd.max(in1, in2);
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new OldMax(ia, ib, ia.dup()))); tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Max(ia, ib, ia.dup()))[0]);
break; break;
case 15: case 15:
t = sd.min(in1, in2); t = sd.min(in1, in2);
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new OldMin(ia, ib, ia.dup()))); tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Min(ia, ib, ia.dup()))[0]);
break; break;
case 16: case 16:
ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5)); ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));
@ -1199,7 +1199,7 @@ public class TransformOpValidation extends BaseOpValidation {
} }
DifferentialFunction[] funcs = sd.functions(); DifferentialFunction[] funcs = sd.ops();
String name = (opName == null ? funcs[0].opName() : opName); String name = (opName == null ? funcs[0].opName() : opName);
String msg = "test: " + i + " - " + name; String msg = "test: " + i + " - " + name;

View File

@ -188,11 +188,11 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
assertEquals(varsOrig.get(j).getVarName(), varsRestored.get(j).getVarName()); assertEquals(varsOrig.get(j).getVarName(), varsRestored.get(j).getVarName());
} }
DifferentialFunction[] fOrig = sd.functions(); DifferentialFunction[] fOrig = sd.ops();
DifferentialFunction[] fRestored = restored.functions(); DifferentialFunction[] fRestored = restored.ops();
assertEquals(fOrig.length, fRestored.length); assertEquals(fOrig.length, fRestored.length);
for (int j = 0; j < sd.functions().length; j++) { for (int j = 0; j < sd.ops().length; j++) {
assertEquals(fOrig[j].getClass(), fRestored[j].getClass()); assertEquals(fOrig[j].getClass(), fRestored[j].getClass());
} }
@ -224,7 +224,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
sd.save(f2, withUpdaterState); sd.save(f2, withUpdaterState);
SameDiff r2 = SameDiff.load(f2, withUpdaterState); SameDiff r2 = SameDiff.load(f2, withUpdaterState);
assertEquals(varsOrig.size(), r2.variables().size()); assertEquals(varsOrig.size(), r2.variables().size());
assertEquals(fOrig.length, r2.functions().length); assertEquals(fOrig.length, r2.ops().length);
assertEquals(sd.getLossVariables(), r2.getLossVariables()); assertEquals(sd.getLossVariables(), r2.getLossVariables());
//Save via stream: //Save via stream:
@ -237,7 +237,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
try(InputStream is = new BufferedInputStream(new FileInputStream(f3))) { try(InputStream is = new BufferedInputStream(new FileInputStream(f3))) {
SameDiff r3 = SameDiff.load(is, withUpdaterState); SameDiff r3 = SameDiff.load(is, withUpdaterState);
assertEquals(varsOrig.size(), r3.variables().size()); assertEquals(varsOrig.size(), r3.variables().size());
assertEquals(fOrig.length, r3.functions().length); assertEquals(fOrig.length, r3.ops().length);
assertEquals(sd.getLossVariables(), r3.getLossVariables()); assertEquals(sd.getLossVariables(), r3.getLossVariables());
} }
} }

View File

@ -19,7 +19,6 @@ package org.nd4j.autodiff.samediff;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.Test; import org.junit.Test;
import org.nd4j.autodiff.samediff.transform.*; import org.nd4j.autodiff.samediff.transform.*;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -58,17 +57,17 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
SDVariable sub = add.sub(add2); SDVariable sub = add.sub(add2);
assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputFunction(add.getVarName()))); assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add.getVarName())));
assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputFunction(add2.getVarName()))); assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add2.getVarName())));
assertFalse(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputFunction(sub.getVarName()))); assertFalse(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(sub.getVarName())));
assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputFunction(add.getVarName()))); assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add.getVarName())));
assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputFunction(add2.getVarName()))); assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add2.getVarName())));
assertFalse(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputFunction(sub.getVarName()))); assertFalse(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(sub.getVarName())));
assertTrue(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputFunction(add.getVarName()))); assertTrue(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(add.getVarName())));
assertTrue(OpPredicate.opNameMatches("ad.*").matches(sd, sd.getVariableOutputFunction(add2.getVarName()))); assertTrue(OpPredicate.opNameMatches("ad.*").matches(sd, sd.getVariableOutputOp(add2.getVarName())));
assertFalse(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputFunction(sub.getVarName()))); assertFalse(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(sub.getVarName())));
SubGraphPredicate p = SubGraphPredicate.withRoot(OpPredicate.classEquals(AddOp.class)); SubGraphPredicate p = SubGraphPredicate.withRoot(OpPredicate.classEquals(AddOp.class));
@ -77,11 +76,11 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
assertEquals(2, l.size()); assertEquals(2, l.size());
SubGraph sg1 = l.get(0); SubGraph sg1 = l.get(0);
assertTrue(sg1.getRootNode() == sd.getVariableOutputFunction(add.getVarName())); assertTrue(sg1.getRootNode() == sd.getVariableOutputOp(add.getVarName()));
assertEquals(0, sg1.getChildNodes().size()); assertEquals(0, sg1.getChildNodes().size());
SubGraph sg2 = l.get(1); SubGraph sg2 = l.get(1);
assertTrue(sg2.getRootNode() == sd.getVariableOutputFunction(add2.getVarName())); assertTrue(sg2.getRootNode() == sd.getVariableOutputOp(add2.getVarName()));
assertEquals(0, sg2.getChildNodes().size()); assertEquals(0, sg2.getChildNodes().size());
} }

View File

@ -59,13 +59,13 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNorma
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance; import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray; import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual; import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing; import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor; import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing; import org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual; import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Min;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil; import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
@ -1759,11 +1759,11 @@ public class SameDiffTests extends BaseNd4jTest {
break; break;
case 7: case 7:
t = sd.max(in1, in2); t = sd.max(in1, in2);
expOut = Nd4j.getExecutioner().exec(new OldMax(ia, ib, ia.dup())); expOut = Nd4j.getExecutioner().exec(new Max(ia, ib, ia.dup()))[0];
break; break;
case 8: case 8:
t = sd.min(in1, in2); t = sd.min(in1, in2);
expOut = Nd4j.getExecutioner().exec(new OldMin(ia, ib, ia.dup())); expOut = Nd4j.getExecutioner().exec(new Min(ia, ib, ia.dup()))[0];
break; break;
case 9: case 9:
ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5)); ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));

View File

@ -16,7 +16,6 @@
package org.nd4j.imports.TFGraphs; package org.nd4j.imports.TFGraphs;
import com.google.common.primitives.Doubles;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.FilenameUtils;
@ -38,7 +37,6 @@ import org.nd4j.autodiff.validation.OpValidation;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.imports.TFGraphs.listener.OpExecOrderListener; import org.nd4j.imports.TFGraphs.listener.OpExecOrderListener;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -46,7 +44,6 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.function.BiFunction; import org.nd4j.linalg.function.BiFunction;
import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.linalg.indexing.conditions.Conditions;
@ -301,7 +298,7 @@ public class TFGraphTestAllHelper {
Map<String,SameDiffOp> fns = graph.getOps(); Map<String,SameDiffOp> fns = graph.getOps();
List<String> execOrder = listener.getOpNamesList(); List<String> execOrder = listener.getOpNamesList();
for(String opName : execOrder){ for(String opName : execOrder){
String[] outputs = graph.getOutputsForFunction(fns.get(opName).getOp()); String[] outputs = graph.getOutputsForOp(fns.get(opName).getOp());
Collections.addAll(varNames, outputs); Collections.addAll(varNames, outputs);
} }
@ -334,8 +331,8 @@ public class TFGraphTestAllHelper {
if(countExceeds > 0){ if(countExceeds > 0){
maxRE = relError.maxNumber().doubleValue(); maxRE = relError.maxNumber().doubleValue();
//Find the op that this variable is produced by //Find the op that this variable is produced by
op = graph.getVariableOutputFunction(varName); op = graph.getVariableOutputOp(varName);
opInputs = graph.getInputsForFunction(op); opInputs = graph.getInputsForOp(op);
} }

View File

@ -732,9 +732,9 @@ public class TensorFlowImportTest extends BaseNd4jTest {
} }
val functions = new HashMap<String, DifferentialFunction>(); val functions = new HashMap<String, DifferentialFunction>();
for (val func: tg.functions()) { for (val func: tg.ops()) {
val ownName = func.getOwnName(); val ownName = func.getOwnName();
val outName = func.outputVariables()[0].getVarName(); String outName = func.outputVariables()[0].getVarName();
assertTrue("Missing ownName: [" + ownName +"]",variables.containsKey(ownName)); assertTrue("Missing ownName: [" + ownName +"]",variables.containsKey(ownName));
assertEquals(ownName, outName); assertEquals(ownName, outName);

View File

@ -72,12 +72,12 @@ import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet; import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps; import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps;
import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace; import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpaceND; import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpaceND;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.Axpy; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.Axpy;
import org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse;
import org.nd4j.linalg.api.ops.impl.transforms.same.Sign; import org.nd4j.linalg.api.ops.impl.transforms.same.Sign;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh; import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh; import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
@ -5226,7 +5226,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
INDArray array = Nd4j.create(new double[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); INDArray array = Nd4j.create(new double[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
INDArray rev = Nd4j.getExecutioner().exec(new OldReverse(array, Nd4j.createUninitialized(array.length()))); INDArray rev = Nd4j.getExecutioner().exec(new Reverse(array, Nd4j.createUninitialized(array.length())))[0];
assertEquals(exp, rev); assertEquals(exp, rev);
} }
@ -5236,7 +5236,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
INDArray rev = Nd4j.getExecutioner().exec(new OldReverse(array, Nd4j.createUninitialized(array.length()))); INDArray rev = Nd4j.getExecutioner().exec(new Reverse(array, Nd4j.createUninitialized(array.length())))[0];
assertEquals(exp, rev); assertEquals(exp, rev);
} }

View File

@ -35,7 +35,7 @@ import org.nd4j.linalg.api.ops.impl.reduce.bool.IsInf;
import org.nd4j.linalg.api.ops.impl.reduce.bool.IsNaN; import org.nd4j.linalg.api.ops.impl.reduce.bool.IsNaN;
import org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero; import org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity; import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldEqualTo; import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -276,7 +276,7 @@ public class MixedDataTypesTests extends BaseNd4jTest {
val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.INT); val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.INT);
val exp = new long[]{1, 0, 0, 1}; val exp = new long[]{1, 0, 0, 1};
val result = Nd4j.getExecutioner().exec(new OldEqualTo(arrayX, arrayY)); val result = Nd4j.getExecutioner().exec(new EqualTo(arrayX, arrayY, arrayX.ulike().castTo(DataType.BOOL)))[0];
assertEquals(DataType.BOOL, result.dataType()); assertEquals(DataType.BOOL, result.dataType());
val arr = result.data().asLong(); val arr = result.data().asLong();
@ -369,13 +369,13 @@ public class MixedDataTypesTests extends BaseNd4jTest {
val result = Nd4j.getExecutioner().exec(op); val result = Nd4j.getExecutioner().exec(op);
} }
@Test(expected = IllegalArgumentException.class) @Test(expected = RuntimeException.class)
public void testTypesValidation_2() { public void testTypesValidation_2() {
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.LONG); val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.LONG);
val exp = new long[]{1, 0, 0, 1}; val exp = new long[]{1, 0, 0, 1};
val result = Nd4j.getExecutioner().exec(new OldEqualTo(arrayX, arrayY)); val result = Nd4j.getExecutioner().exec(new EqualTo(arrayX, arrayY, arrayX.ulike().castTo(DataType.BOOL)))[0];
val arr = result.data().asLong(); val arr = result.data().asLong();
assertArrayEquals(exp, arr); assertArrayEquals(exp, arr);

View File

@ -45,7 +45,7 @@ import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance; import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp; import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Log; import org.nd4j.linalg.api.ops.impl.transforms.strict.Log;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SetRange; import org.nd4j.linalg.api.ops.impl.transforms.strict.SetRange;
@ -205,7 +205,7 @@ public class OpExecutionerTests extends BaseNd4jTest {
INDArray x = Nd4j.ones(5); INDArray x = Nd4j.ones(5);
INDArray xDup = x.dup(); INDArray xDup = x.dup();
INDArray solution = Nd4j.valueArrayOf(5, 1.0); INDArray solution = Nd4j.valueArrayOf(5, 1.0);
opExecutioner.exec(new OldMulOp(x, xDup, x)); opExecutioner.exec(new MulOp(x, xDup, x));
assertEquals(solution, x); assertEquals(solution, x);
} }

View File

@ -55,7 +55,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.HistogramFixedWidth;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp; import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Log; import org.nd4j.linalg.api.ops.impl.transforms.strict.Log;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SetRange; import org.nd4j.linalg.api.ops.impl.transforms.strict.SetRange;
@ -236,7 +236,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest {
INDArray x = Nd4j.ones(5); INDArray x = Nd4j.ones(5);
INDArray xDup = x.dup(); INDArray xDup = x.dup();
INDArray solution = Nd4j.valueArrayOf(5, 1.0); INDArray solution = Nd4j.valueArrayOf(5, 1.0);
opExecutioner.exec(new OldMulOp(x, xDup, x)); opExecutioner.exec(new MulOp(x, xDup, x));
assertEquals(solution, x); assertEquals(solution, x);
} }

View File

@ -72,8 +72,9 @@ public class PaddingTests extends BaseNd4jTest {
@Test @Test
public void testPad() { public void testPad() {
INDArray start = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3); INDArray start = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
INDArray ret = Nd4j.pad(start, new int[] {5, 5}, Nd4j.PadMode.CONSTANT); INDArray ret = Nd4j.pad(start, 5, 5);
double[][] data = new double[][] {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, double[][] data = new double[][] {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.},
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.},
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.},

View File

@ -64,8 +64,7 @@ public class PaddingTestsC extends BaseNd4jTest {
INDArray ret = Nd4j.create(new double[] {1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, INDArray ret = Nd4j.create(new double[] {1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
4, 4, 4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
4, 4, 4, 4, 4, 4, 4, 4}, new int[] {1, 1, 8, 8}); 4, 4, 4, 4, 4, 4, 4, 4}, new int[] {1, 1, 8, 8});
INDArray padded = Nd4j.pad(ret, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}}, INDArray padded = Nd4j.pad(ret, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}});
Nd4j.PadMode.CONSTANT);
INDArray assertion = Nd4j.create(new double[] {1, 1, 1, 1, 1, 1, 1, 1, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, 3, 3, 3, 3, INDArray assertion = Nd4j.create(new double[] {1, 1, 1, 1, 1, 1, 1, 1, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, 3, 3, 3, 3,
3, 3, 3, 3, 0, 4, 4, 4, 4, 4, 4, 4, 4, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, 3, 3, 3, 3, 0, 4, 4, 4, 4, 4, 4, 4, 4, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0,
@ -104,8 +103,7 @@ public class PaddingTestsC extends BaseNd4jTest {
// FIXME: int cast // FIXME: int cast
int outWidth = Convolution.outSize((int) h, kh, sy, ph, 1, true); int outWidth = Convolution.outSize((int) h, kh, sy, ph, 1, true);
int outHeight = Convolution.outSize((int) w, kw, sx, pw, 1, true); int outHeight = Convolution.outSize((int) w, kw, sx, pw, 1, true);
INDArray padded = Nd4j.pad(linspaced, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}}, INDArray padded = Nd4j.pad(linspaced, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}});
Nd4j.PadMode.CONSTANT);
System.out.println(padded); System.out.println(padded);
} }

View File

@ -127,8 +127,7 @@ public class IndexingTestsC extends BaseNd4jTest {
4, 4, 4, 4, 4, 4, 4, 4}, new long[] {1, 1, 8, 8}); 4, 4, 4, 4, 4, 4, 4, 4}, new long[] {1, 1, 8, 8});
INDArray padded = Nd4j.pad(img, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}}, INDArray padded = Nd4j.pad(img, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}});
Nd4j.PadMode.CONSTANT);
INDArray get = padded.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(i, sy, iLim), INDArray get = padded.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(i, sy, iLim),
NDArrayIndex.interval(j, sx, jLim)); NDArrayIndex.interval(j, sx, jLim));