[WIP] Various fixes, mostly SameDiff/Nd4j (#110)
Branch: master

Commits (each signed-off-by: Ryan Nett <rnett@skymind.io>):

* Nd4j pad update
* switched from guava Immutables to Collections.unmodifiableList/Map
* javadoc
* use new pad
* conv tests use OpValidation
* deconv3d overrides
* test fix for the new pad method
* more test fixes
* more test fixes
* rename SameDiff function methods to op (except for the actual SameDiff function ones)
* more pad overloads, test fix
* test updates
* conv1d test
* remove Conv1D tf import (there isn't a TF conv1d op)
* remove numThreads from Nd4j
* replace Old ops with their newer versions, deprecate ones that haven't already been deprecated
* remove use of setNumThreads
* fix for Reverse and ATan2
* fix test for wrong equals type
* well it works now
* better javadocs
* NonNulls
* better array literal
* re-add tf import stuff (will remove later)
* conv1d config load fix
* partial config usage changes
* remove Old op classes
* config property fixes
* removed one too many ops
parent: eea3062ccf
commit: 2b0d7b3b52
AlphaDropout.java:

@@ -24,11 +24,10 @@ import org.deeplearning4j.nn.workspace.ArrayType;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
+import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
 import org.nd4j.linalg.api.ops.random.impl.AlphaDropOut;
 import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
 import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.ops.transforms.Transforms;
 import org.nd4j.linalg.schedule.ISchedule;
 import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
 import org.nd4j.shade.jackson.annotation.JsonProperty;
@@ -139,7 +138,7 @@ public class AlphaDropout implements IDropout {
 //a * (x * d + alphaPrime * (1-d)) + b
 INDArray inverseMask = mask.rsub(1.0);
 INDArray aPOneMinusD = inverseMask.muli(alphaPrime);
-Nd4j.getExecutioner().exec(new OldMulOp(inputActivations, mask, output)); //out = x * d
+Nd4j.getExecutioner().exec(new MulOp(inputActivations, mask, output)); //out = x * d
 output.addi(aPOneMinusD).muli(a).addi(b);
 
 //Nd4j.getExecutioner().exec(new AlphaDropOut(inputActivations, output, p, a, alphaPrime, b));
@@ -152,7 +151,7 @@ public class AlphaDropout implements IDropout {
 //dL/dIn = dL/dOut * dOut/dIn
 // dOut/dIn = 0 if dropped (d=0), or a otherwise (d=1)
 mask.muli(a);
-Nd4j.getExecutioner().exec(new OldMulOp(gradAtOutput, mask, gradAtInput));
+Nd4j.getExecutioner().exec(new MulOp(gradAtOutput, mask, gradAtInput));
 mask = null;
 return gradAtInput;
 }
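Note: the pattern repeated throughout this PR is that the legacy Old* pairwise transforms (OldMulOp, OldAddOp, OldSubOp, OldDivOp) are removed in favour of their replacements with the same (x, y, output) constructor shape. A minimal sketch of the new-style call — the op class and executioner call are taken from the diff above; the surrounding harness is illustrative only:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
    import org.nd4j.linalg.factory.Nd4j;

    public class MulOpMigrationSketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.rand(2, 3);
            INDArray mask = Nd4j.rand(2, 3);
            INDArray out = Nd4j.createUninitialized(x.dataType(), x.shape());

            // New-style call, as in the diff: the (x, y, z) constructor writes
            // the element-wise product of x and y into the pre-allocated z.
            Nd4j.getExecutioner().exec(new MulOp(x, mask, out));

            System.out.println(out);
        }
    }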
Dropout.java:

@@ -24,7 +24,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
+import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
 import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.schedule.ISchedule;
@@ -153,7 +153,7 @@ public class Dropout implements IDropout {
 
 mask = workspaceMgr.createUninitialized(ArrayType.INPUT, output.dataType(), output.shape(), output.ordering()).assign(1.0);
 Nd4j.getExecutioner().exec(new DropOutInverted(mask, mask, currP));
-Nd4j.getExecutioner().exec(new OldMulOp(inputCast, mask, output));
+Nd4j.getExecutioner().exec(new MulOp(inputCast, mask, output));
 return output;
 }
 
@@ -171,7 +171,7 @@ public class Dropout implements IDropout {
 if(m.dataType() != gradAtInput.dataType()){
 m = m.castTo(gradAtInput.dataType());
 }
-Nd4j.getExecutioner().exec(new OldMulOp(gradAtOutput, m, gradAtInput));
+Nd4j.getExecutioner().exec(new MulOp(gradAtOutput, m, gradAtInput));
 mask = null;
 return gradAtInput;
 }
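Note: for context, the forward pass above is inverted dropout — the mask starts as ones and DropOutInverted modifies it in place (the inverted variant conventionally also rescales survivors so inference needs no adjustment; that behaviour is standard dropout practice, not something shown in this diff). A sketch of the same sequence outside the layer:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
    import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
    import org.nd4j.linalg.factory.Nd4j;

    public class InvertedDropoutSketch {
        public static void main(String[] args) {
            double p = 0.5; // keep probability
            INDArray input = Nd4j.rand(4, 4);
            INDArray output = Nd4j.createUninitialized(input.dataType(), input.shape());

            // Mask starts as all ones, then DropOutInverted modifies it in place,
            // mirroring the forward pass in the hunk above.
            INDArray mask = Nd4j.createUninitialized(input.dataType(), input.shape()).assign(1.0);
            Nd4j.getExecutioner().exec(new DropOutInverted(mask, mask, p));

            // output = input * mask, element-wise, via the new MulOp
            Nd4j.getExecutioner().exec(new MulOp(input, mask, output));
            System.out.println(output);
        }
    }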
GaussianDropout.java:

@@ -22,7 +22,7 @@ import org.deeplearning4j.nn.workspace.ArrayType;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
+import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
 import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.schedule.ISchedule;
@@ -88,7 +88,7 @@ public class GaussianDropout implements IDropout {
 noise = workspaceMgr.createUninitialized(ArrayType.INPUT, output.dataType(), inputActivations.shape(), inputActivations.ordering());
 Nd4j.getExecutioner().exec(new GaussianDistribution(noise, 1.0, stdev));
 
-return Nd4j.getExecutioner().exec(new OldMulOp(inputActivations, noise, output));
+return Nd4j.getExecutioner().exec(new MulOp(inputActivations, noise, output))[0];
 }
 
 @Override
@@ -96,7 +96,7 @@ public class GaussianDropout implements IDropout {
 Preconditions.checkState(noise != null, "Cannot perform backprop: GaussianDropout noise array is absent (already cleared?)");
 //out = in*y, where y ~ N(1, stdev)
 //dL/dIn = dL/dOut * dOut/dIn = y * dL/dOut
-Nd4j.getExecutioner().exec(new OldMulOp(gradAtOutput, noise, gradAtInput));
+Nd4j.getExecutioner().exec(new MulOp(gradAtOutput, noise, gradAtInput));
 noise = null;
 return gradAtInput;
 }
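Note the new [0] index on the return above: exec on the replacement ops yields all of the op's outputs as an INDArray[], whereas the legacy transform exec returned the single result array directly. A sketch of the new calling convention (harness illustrative):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
    import org.nd4j.linalg.factory.Nd4j;

    public class CustomOpOutputsSketch {
        public static void main(String[] args) {
            INDArray a = Nd4j.rand(2, 2);
            INDArray b = Nd4j.rand(2, 2);
            INDArray out = Nd4j.createUninitialized(a.dataType(), a.shape());

            // exec returns every output of the op as an array;
            // MulOp has exactly one output, so callers take index 0.
            INDArray[] outputs = Nd4j.getExecutioner().exec(new MulOp(a, b, out));
            INDArray product = outputs[0]; // same values as `out`
            System.out.println(product);
        }
    }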
GaussianNoise.java:

@@ -19,7 +19,7 @@ package org.deeplearning4j.nn.conf.dropout;
 import lombok.Data;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAddOp;
+import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
 import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.schedule.ISchedule;
@@ -69,7 +69,7 @@ public class GaussianNoise implements IDropout {
 INDArray noise = Nd4j.createUninitialized(output.dataType(), inputActivations.shape(), inputActivations.ordering());
 Nd4j.getExecutioner().exec(new GaussianDistribution(noise, 0, currS));
 
-Nd4j.getExecutioner().exec(new OldAddOp(inputActivations, noise, output));
+Nd4j.getExecutioner().exec(new AddOp(inputActivations, noise, output));
 return output;
 }
 
BernoulliReconstructionDistribution.java:

@@ -24,7 +24,7 @@ import org.nd4j.linalg.activations.impl.ActivationHardSigmoid;
 import org.nd4j.linalg.activations.impl.ActivationSigmoid;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldLessThan;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.BooleanIndexing;
 import org.nd4j.linalg.indexing.conditions.Conditions;
@@ -144,7 +144,7 @@ public class BernoulliReconstructionDistribution implements ReconstructionDistri
 
 INDArray out = Nd4j.createUninitialized(DataType.BOOL, p.shape());
 
-Nd4j.getExecutioner().execAndReturn(new OldLessThan(rand, p, out));
+Nd4j.getExecutioner().execAndReturn(new LessThan(rand, p, out));
 return out.castTo(DataType.FLOAT);
 }
 
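Note: the comparison-op migration keeps the same sampling idiom — a Bernoulli(p) draw is produced per element by testing uniform noise against the probabilities into a BOOL buffer, then casting to FLOAT. A minimal sketch using the new LessThan class from the diff:

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan;
    import org.nd4j.linalg.factory.Nd4j;

    public class BernoulliSampleSketch {
        public static void main(String[] args) {
            INDArray p = Nd4j.rand(3, 3);    // per-element probabilities
            INDArray rand = Nd4j.rand(3, 3); // uniform [0,1) noise
            INDArray out = Nd4j.createUninitialized(DataType.BOOL, p.shape());

            // out[i] = rand[i] < p[i], i.e. a Bernoulli(p[i]) draw per element
            Nd4j.getExecutioner().execAndReturn(new LessThan(rand, p, out));
            System.out.println(out.castTo(DataType.FLOAT));
        }
    }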
WeightNoise.java:

@@ -22,8 +22,8 @@ import org.deeplearning4j.nn.api.ParamInitializer;
 import org.deeplearning4j.nn.conf.distribution.Distribution;
 import org.deeplearning4j.nn.conf.distribution.Distributions;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAddOp;
-import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
+import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
+import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
 import org.nd4j.linalg.factory.Nd4j;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 import org.deeplearning4j.nn.workspace.ArrayType;
@@ -86,9 +86,9 @@ public class WeightNoise implements IWeightNoise {
 INDArray out = workspaceMgr.createUninitialized(ArrayType.INPUT, param.dataType(), param.shape(), param.ordering());
 
 if (additive) {
-Nd4j.getExecutioner().exec(new OldAddOp(param, noise,out));
+Nd4j.getExecutioner().exec(new AddOp(param, noise,out));
 } else {
-Nd4j.getExecutioner().exec(new OldMulOp(param, noise, out));
+Nd4j.getExecutioner().exec(new MulOp(param, noise, out));
 }
 return out;
 }
BatchNormalization.java:

@@ -34,8 +34,8 @@ import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
 import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
 import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
 import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
-import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldDivOp;
-import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldSubOp;
+import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp;
+import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp;
 import org.nd4j.linalg.api.shape.Shape;
 import org.nd4j.linalg.exception.ND4JOpProfilerException;
 import org.nd4j.linalg.factory.Nd4j;
@@ -205,7 +205,7 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
 INDArray batchMean = helper.getMeanCache(dataType);
 INDArray batchVar = helper.getVarCache(dataType);
 
-Nd4j.getExecutioner().exec(new OldSubOp(globalMean, batchMean, dGlobalMeanView)); //deltaGlobalMean = globalMean[t] - batchMean
+Nd4j.getExecutioner().exec(new SubOp(globalMean, batchMean, dGlobalMeanView)); //deltaGlobalMean = globalMean[t] - batchMean
 dGlobalMeanView.muli(1-layerConf().getDecay());
 
 if(layerConf().isUseLogStd()){
@@ -219,12 +219,12 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
 
 double decay = layerConf().getDecay();
 INDArray varip1 = vari.mul(decay).addi(batchVar.mul(1-decay));
-Nd4j.getExecutioner().exec(new OldDivOp(vari, varip1, dGlobalLog10StdView));
+Nd4j.getExecutioner().exec(new DivOp(vari, varip1, dGlobalLog10StdView));
 Transforms.log(dGlobalLog10StdView, false);
 dGlobalLog10StdView.muli(ONE_ON_2LOGE_10);
 } else {
 //Use variance estimate parameterization. This was only option up to and including 1.0.0-beta3
-Nd4j.getExecutioner().exec(new OldSubOp(globalVar, batchVar, dGlobalVarView)); //deltaGlobalVar = globalVar[t] - batchVar
+Nd4j.getExecutioner().exec(new SubOp(globalVar, batchVar, dGlobalVarView)); //deltaGlobalVar = globalVar[t] - batchVar
 dGlobalVarView.muli(1 - layerConf().getDecay());
 }
 
@@ -343,7 +343,7 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
 And use the same idea for global variance estimate
 */
 
-Nd4j.getExecutioner().exec(new OldSubOp(globalMean, batchMean, dGlobalMeanView)); //deltaGlobalMean = globalMean[t] - batchMean
+Nd4j.getExecutioner().exec(new SubOp(globalMean, batchMean, dGlobalMeanView)); //deltaGlobalMean = globalMean[t] - batchMean
 dGlobalMeanView.muli(1-layerConf().getDecay());
 
 if(layerConf().isUseLogStd()){
@@ -357,12 +357,12 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
 
 double decay = layerConf().getDecay();
 INDArray varip1 = vari.mul(decay).addi(batchVar.mul(1-decay));
-Nd4j.getExecutioner().exec(new OldDivOp(vari, varip1, dGlobalLog10StdView));
+Nd4j.getExecutioner().exec(new DivOp(vari, varip1, dGlobalLog10StdView));
 Transforms.log(dGlobalLog10StdView, false);
 dGlobalLog10StdView.muli(ONE_ON_2LOGE_10);
 } else {
 //Use variance estimate parameterization. This was only option up to and including 1.0.0-beta3
-Nd4j.getExecutioner().exec(new OldSubOp(globalVar, batchVar, dGlobalVarView)); //deltaGlobalVar = globalVar[t] - batchVar
+Nd4j.getExecutioner().exec(new SubOp(globalVar, batchVar, dGlobalVarView)); //deltaGlobalVar = globalVar[t] - batchVar
 dGlobalVarView.muli(1 - layerConf().getDecay());
 }
 
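Note: the SubOp calls above implement the running-statistics update by writing a delta into a gradient view rather than assigning the new average directly. Assuming the view is consumed subtractively by the updater (which is how the deltaGlobalMean comments read), the arithmetic reduces to a standard exponential moving average — a scalar sanity check:

    public class RunningMeanUpdateCheck {
        public static void main(String[] args) {
            double globalMean = 0.8, batchMean = 0.5, decay = 0.9;

            // Gradient-view form from the hunk: delta = (global - batch) * (1 - decay)
            double delta = (globalMean - batchMean) * (1 - decay);
            double viaDelta = globalMean - delta;

            // Standard exponential-moving-average form
            double ema = decay * globalMean + (1 - decay) * batchMean;

            System.out.println(viaDelta + " == " + ema); // both ~0.77
        }
    }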
LocalResponseNormalization.java:

@@ -23,10 +23,9 @@ import org.deeplearning4j.nn.gradient.DefaultGradient;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.layers.AbstractLayer;
 import org.deeplearning4j.nn.layers.LayerHelper;
-import org.deeplearning4j.nn.layers.mkldnn.MKLDNNLocalResponseNormalizationHelper;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
+import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
 import org.nd4j.linalg.exception.ND4JOpProfilerException;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.INDArrayIndex;
@@ -187,7 +186,7 @@ public class LocalResponseNormalization
 
 // gx = gy * unitScale**-beta - 2 * alpha * beta * sumPart/unitScale * a^i_{x,y} - rearranged for more in-place ops
 INDArray nextEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), epsilon.shape(), epsilon.ordering());
-Nd4j.getExecutioner().exec(new OldMulOp(epsilon, scale, nextEpsilon));
+Nd4j.getExecutioner().exec(new MulOp(epsilon, scale, nextEpsilon));
 nextEpsilon.subi(sumPart.muli(input).divi(unitScale).muli(2 * alpha * beta));
 return new Pair<>(retGradient, nextEpsilon);
 }
@@ -257,7 +256,7 @@ public class LocalResponseNormalization
 unitScale = sumPart.mul(alpha).addi(k);
 // y = x * unitScale**-beta
 scale = Transforms.pow(unitScale, -beta, true);
-Nd4j.getExecutioner().exec(new OldMulOp(input, scale, activations));
+Nd4j.getExecutioner().exec(new MulOp(input, scale, activations));
 } else {
 // unitScale = (k + alpha * sum_{j=max(0, i - n/2)}^{max(N-1, i + n/2)} (a^j_{x,y})^2 )
 sumPart.muli(alpha, activations).addi(k);
LSTMHelpers.java:

@@ -35,7 +35,7 @@ import org.nd4j.linalg.activations.IActivation;
 import org.nd4j.linalg.activations.impl.ActivationSigmoid;
 import org.nd4j.linalg.api.memory.MemoryWorkspace;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
+import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
 import org.nd4j.linalg.api.ops.impl.transforms.same.TimesOneMinus;
 import org.nd4j.linalg.api.shape.Shape;
 import org.nd4j.linalg.factory.Nd4j;
@@ -554,7 +554,7 @@ public class LSTMHelpers {
 
 //Normally would use zo.dup() in above line, but won't be using zo again (for this time step). Ditto for zf, zg, zi
 INDArray deltao = deltaoNext;
-Nd4j.getExecutioner().exec(new OldMulOp(nablaOut, sigmahOfS, deltao));
+Nd4j.getExecutioner().exec(new MulOp(nablaOut, sigmahOfS, deltao));
 if (sigmoidGates) {
 INDArray sigmaoPrimeOfZo = Nd4j.getExecutioner().exec(new TimesOneMinus(ao.dup('f'))); //Equivalent to sigmoid deriv on zo
 deltao.muli(sigmaoPrimeOfZo);
@@ -607,7 +607,7 @@ public class LSTMHelpers {
 deltag.muli(ai);
 deltag.muli(nablaCellState);
 } else {
-INDArray temp2 = Nd4j.getExecutioner().exec(new OldMulOp(ai, nablaCellState, Nd4j.createUninitialized(inputWeights.dataType(), ai.shape(), 'f')));
+INDArray temp2 = Nd4j.getExecutioner().exec(new MulOp(ai, nablaCellState, Nd4j.createUninitialized(inputWeights.dataType(), ai.shape(), 'f')))[0];
 deltag.assign(gateActivationFn.backprop(fwdPass.gz[time], temp2).getFirst());
 //TODO activation functions with params; optimize (no assign)
 }
@@ -616,7 +616,7 @@ public class LSTMHelpers {
 //Network input delta:
 INDArray zi = fwdPass.iz[time];
 INDArray deltai = deltaiNext;
-temp = Nd4j.getExecutioner().exec(new OldMulOp(ag, nablaCellState, Nd4j.createUninitialized(inputWeights.dataType(), deltai.shape(), 'f')));
+temp = Nd4j.getExecutioner().exec(new MulOp(ag, nablaCellState, Nd4j.createUninitialized(inputWeights.dataType(), deltai.shape(), 'f')))[0];
 deltai.assign(afn.backprop(zi, temp).getFirst());
 //TODO activation functions with params; also: optimize this (no assign)
 //Shape: [m,n^L]
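Note: TimesOneMinus in the hunk above is the usual trick of computing the sigmoid derivative from the forward activations — since a = sigmoid(z), sigmoid'(z) = a*(1-a), so no second sigmoid evaluation is needed. A small sketch (Transforms.sigmoid here only produces activations; the op classes are the ones used in the diff):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.same.TimesOneMinus;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.ops.transforms.Transforms;

    public class SigmoidDerivSketch {
        public static void main(String[] args) {
            INDArray z = Nd4j.rand(2, 2);
            INDArray a = Transforms.sigmoid(z, true); // a = sigmoid(z), copy

            // sigmoid'(z) = a * (1 - a): TimesOneMinus computes x*(1-x) from the
            // forward activations directly.
            INDArray deriv = Nd4j.getExecutioner().exec(new TimesOneMinus(a.dup()));
            System.out.println(deriv);
        }
    }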
DifferentialFunction.java:

@@ -26,7 +26,6 @@ import onnx.OnnxProto3;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
-import org.nd4j.graph.DataType;
 import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
 import org.nd4j.imports.descriptors.properties.AttributeAdapter;
 import org.nd4j.imports.descriptors.properties.PropertyMapping;
@@ -520,7 +519,7 @@ public abstract class DifferentialFunction {
 * @return the arguments for a given function
 */
 public SDVariable[] args() {
-return sameDiff.getInputVariablesForFunction(this);
+return sameDiff.getInputVariablesForOp(this);
 }
 
 /**
@@ -661,7 +660,7 @@ public abstract class DifferentialFunction {
 }
 
 if(sameDiff != null && !(this instanceof SDVariable))
-sameDiff.putFunctionForId(ownName,this);
+sameDiff.putOpForId(ownName,this);
 }
 }
 
EvaluationRecord.java:

@@ -23,6 +23,7 @@ import com.google.common.collect.Lists;
 
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -39,12 +40,12 @@ import org.nd4j.evaluation.IMetric;
 @Getter
 public class EvaluationRecord {
 
-private ImmutableMap<String, List<IEvaluation>> evaluations;
+private Map<String, List<IEvaluation>> evaluations;
 private Map<Class<? extends IEvaluation>, IEvaluation> classEvaluations = new HashMap<>();
 private boolean isEmpty = true;
 
 public EvaluationRecord(Map<String, List<IEvaluation>> evaluations) {
-this.evaluations = ImmutableMap.copyOf(evaluations);
+this.evaluations = Collections.unmodifiableMap(evaluations);
 
 for (List<IEvaluation> le : evaluations.values()) {
 for (IEvaluation e : le) {
@@ -68,7 +69,7 @@ public class EvaluationRecord {
 /**
 * Get all evaluations
 */
-public ImmutableMap<String, List<IEvaluation>> evaluations() {
+public Map<String, List<IEvaluation>> evaluations() {
 return evaluations;
 }
 
History.java:

@@ -16,8 +16,8 @@
 
 package org.nd4j.autodiff.listeners.records;
 
-import com.google.common.collect.ImmutableList;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 
 import lombok.Getter;
@@ -49,11 +49,11 @@ public class History {
 
 public History(List<EvaluationRecord> training, List<EvaluationRecord> validation, LossCurve loss,
 long trainingTimeMillis, List<Long> validationTimesMillis){
-trainingHistory = ImmutableList.copyOf(training);
-validationHistory = ImmutableList.copyOf(validation);
+trainingHistory = Collections.unmodifiableList(training);
+validationHistory = Collections.unmodifiableList(validation);
 this.lossCurve = loss;
 this.trainingTimeMillis = trainingTimeMillis;
-this.validationTimesMillis = ImmutableList.copyOf(validationTimesMillis);
+this.validationTimesMillis = Collections.unmodifiableList(validationTimesMillis);
 }
 
 /**
LossCurve.java:

@@ -16,8 +16,8 @@
 
 package org.nd4j.autodiff.listeners.records;
 
-import com.google.common.collect.ImmutableList;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import lombok.Getter;
 import lombok.NonNull;
@@ -35,7 +35,7 @@ public class LossCurve {
 private INDArray lossValues;
 
 public LossCurve(List<Loss> losses){
-lossNames = ImmutableList.copyOf(losses.get(0).getLossNames());
+lossNames = Collections.unmodifiableList(losses.get(0).getLossNames());
 int numLossValues = losses.get(0).lossValues().length;
 lossValues = Nd4j.create(DataType.FLOAT, losses.size(), losses.get(0).lossValues().length);
 
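Note: beyond the op renames, this PR swaps Guava's ImmutableList/ImmutableMap for Collections.unmodifiableList/Map in the listener records above. One behavioural difference worth flagging: Guava's copyOf takes a defensive snapshot, while the JDK wrappers are live read-only views of the caller's collection. A sketch of the distinction:

    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.List;

    public class UnmodifiableVsCopySketch {
        public static void main(String[] args) {
            List<String> backing = new ArrayList<>();
            backing.add("loss");

            // Collections.unmodifiableList wraps the SAME list: writes through
            // the wrapper throw, but later changes to `backing` stay visible.
            List<String> view = Collections.unmodifiableList(backing);
            backing.add("accuracy");
            System.out.println(view.size()); // 2 - the view tracks the backing list

            // Guava's ImmutableList.copyOf (removed by this PR) would instead
            // have snapshotted the list, so callers should now avoid mutating
            // collections they pass into History / LossCurve / EvaluationRecord.
        }
    }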
SameDiff.java:

@@ -466,8 +466,11 @@ public class SameDiff extends SDBaseOps {
 }
 
 /**
-* Set the current {@link Listener} instances.
-* Note that
+* Set the current SameDiff-wide {@link Listener} instances.
+*
+* Note that this will overwrite the current listener list.
+* If you want to use additional listeners for a single operation,
+* use the listener arguments in those methods (e.g. {@link #fit()} and {@link FitConfig#listeners(Listener...)}).
 *
 * @param listeners Listeners
 */
@@ -476,19 +479,37 @@
 addListeners(listeners);
 }
 
+/**
+* See {@link #setListeners(Listener...)}.
+*/
 public void setListeners(Collection<? extends Listener> listeners) {
 this.listeners.clear();
 addListeners(listeners);
 }
 
 
+/**
+* Add SameDiff-wide {@link Listener} instances.
+*
+* If you want to use additional listeners for a single operation,
+* use the listener arguments in those methods (e.g. {@link #fit()} and {@link FitConfig#listeners(Listener...)}).
+*
+* @param listeners Listeners
+*/
 public void addListeners(Listener... listeners) {
 addListeners(Arrays.asList(listeners));
 }
 
+/**
+* See {@link #addListeners(Listener...)}.
+*/
 public void addListeners(Collection<? extends Listener> listeners) {
 this.listeners.addAll(listeners);
 }
 
+/**
+* Gets the current SameDiff-wide listeners.
+*/
 public List<Listener> getListeners() {
 return listeners;
 }
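Note: usage of the reworked listener API documented above. The MyListener class is hypothetical, and the sketch assumes BaseListener provides no-op default implementations (as in recent ND4J versions) — only setListeners/addListeners/getListeners come from the diff itself:

    import java.util.Arrays;
    import org.nd4j.autodiff.listeners.BaseListener;
    import org.nd4j.autodiff.samediff.SameDiff;

    public class ListenerSetupSketch {
        // Hypothetical listener: assumes BaseListener supplies no-op defaults.
        static class MyListener extends BaseListener { }

        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();

            sd.setListeners(new MyListener());                // replaces the list
            sd.addListeners(Arrays.asList(new MyListener())); // appends to it
            System.out.println(sd.getListeners().size());     // 2
        }
    }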
@@ -585,6 +606,9 @@
 }
 
 
+/**
+* Gets all operations in a given name scope.
+*/
 public List<SameDiffOp> getOpsInScope(NameScope scope) {
 ArrayList<SameDiffOp> ops = new ArrayList<>();
 for (SameDiffOp v : this.ops.values()) {
@@ -594,6 +618,16 @@
 return ops;
 }
 
+/**
+* See {@link #getOpsInScope(NameScope)}.
+*/
+public List<SameDiffOp> getOpsInScope(String scope){
+return getOpsInScope(new NameScope(this, scope));
+}
+
+/**
+* Gets all variables in a given name scope.
+*/
 public List<SDVariable> getVariablesInScope(NameScope scope) {
 ArrayList<SDVariable> vars = new ArrayList<>();
 for (SDVariable v : variables()) {
@@ -603,6 +637,13 @@
 return vars;
 }
 
+/**
+* See {@link #getVariablesInScope(NameScope)}.
+*/
+public List<SDVariable> getVariablesInScope(String scope){
+return getVariablesInScope(new NameScope(this, scope));
+}
+
 /**
 * @param sameDiff
 * @return
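Note: the new String overloads make scope queries one-liners. This sketch assumes scopes are opened with withNameScope(...) (AutoCloseable) as in current SameDiff — only the two getXInScope(String) overloads are taken from the diff itself:

    import java.util.List;
    import org.nd4j.autodiff.samediff.NameScope;
    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;

    public class NameScopeQuerySketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();

            // Assumption: variables and ops created inside the scope get the
            // "encoder/" name prefix.
            try (NameScope scope = sd.withNameScope("encoder")) {
                SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 10);
                SDVariable w = sd.var("w", DataType.FLOAT, 10, 4);
                in.mmul(w);
            }

            // The new String overloads from this PR:
            List<SDVariable> vars = sd.getVariablesInScope("encoder");
            System.out.println(vars.size() + " variables, "
                    + sd.getOpsInScope("encoder").size() + " ops in scope");
        }
    }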
@@ -638,8 +679,8 @@
 function.getSameDiff());
 clone.setSameDiff(sameDiff);
 clone.setOwnName(function.getOwnName());
-if (sameDiff.functionExists(function.getOwnName()))
-sameDiff.putFunctionForId(function.getOwnName(), function);
+if (sameDiff.opExists(function.getOwnName()))
+sameDiff.putOpForId(function.getOwnName(), function);
 newFunctions.put(function.getOwnName(), clone);
 
 val argsForFunction = function.args();
@@ -672,17 +713,21 @@
 * @param id the function id to test for
 * @return true if the function id exists, false otherwise
 */
-public boolean functionExists(String id) {
+public boolean opExists(String id) {
 return ops.containsKey(id);
 }
 
-public DifferentialFunction functionOutputFor(String varName) {
-if (variables.get(varName).getOutputOfOp() == null)
+/**
+* Get the differential function (if any) that this variable is the output for
+*
+* @param variableName Name of the variable
+* @return The differential function that this variable is an output of, or null if it is not the output of a function
+*/
+public DifferentialFunction getVariableOutputOp(String variableName) {
+Preconditions.checkState(variables.containsKey(variableName), "No variable with name \"%s\" found in graph", variableName);
+if (variables.get(variableName).getOutputOfOp() == null)
 return null;
-String outName = variables.get(varName).getOutputOfOp();
-if (outName == null)
-return null;
-return ops.get(outName).getOp();
+return ops.get(variables.get(variableName).getOutputOfOp()).getOp();
 }
 
 /**
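Note: taken together, the renames in this PR map the old function-centric names onto op-centric ones: functionExists -> opExists, getFunctionById -> getOpById, putFunctionForId -> putOpForId, getInputsForFunction -> getInputsForOp, getOutputsForFunction -> getOutputsForOp, getInput/OutputVariablesForFunction -> getInput/OutputVariablesForOp, removeArgFromFunction -> removeArgFromOp, functions() -> ops(), and functionOutputFor -> getVariableOutputOp (now with a precondition on the variable name, so unknown names fail fast). A sketch of the relocated getVariableOutputOp:

    import org.nd4j.autodiff.functions.DifferentialFunction;
    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;

    public class VariableOutputOpSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable a = sd.var("a", DataType.FLOAT, 2, 2);
            SDVariable b = sd.var("b", DataType.FLOAT, 2, 2);
            SDVariable sum = a.add("sum", b);

            // "sum" is produced by an add op, so this returns that op:
            DifferentialFunction op = sd.getVariableOutputOp("sum");
            System.out.println(op == null ? "<none>" : op.opName());

            // Plain variables are not the output of any op:
            System.out.println(sd.getVariableOutputOp("a")); // null
        }
    }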
@@ -691,7 +736,7 @@ public class SameDiff extends SDBaseOps {
 * @param id the id of the function
 * @return the function for the given id if it exists
 */
-public DifferentialFunction getFunctionById(@NonNull String id) {
+public DifferentialFunction getOpById(@NonNull String id) {
 if (!ops.containsKey(id)) {
 throw new ND4JIllegalStateException("No function with id " + id + " found!");
 }
@@ -705,7 +750,7 @@
 * @param id the id of the function
 * @param function the function
 */
-public void putFunctionForId(String id, DifferentialFunction function) {
+public void putOpForId(String id, DifferentialFunction function) {
 if (ops.containsKey(id) && ops.get(id).getOp() == null) {
 throw new ND4JIllegalStateException("Function by id already exists!");
 } else if (function instanceof SDVariable) {
@@ -726,7 +771,7 @@
 * @param function the function to get the inputs for
 * @return the input ids for a given function
 */
-public String[] getInputsForFunction(DifferentialFunction function) {
+public String[] getInputsForOp(DifferentialFunction function) {
 if (!ops.containsKey(function.getOwnName()))
 throw new ND4JIllegalStateException("Illegal function instance id found " + function.getOwnName());
 List<String> inputs = ops.get(function.getOwnName()).getInputsToOp();
@@ -739,7 +784,7 @@
 * @param function the function to get the outputs for
 * @return the outputs ids for a given function
 */
-public String[] getOutputsForFunction(DifferentialFunction function) {
+public String[] getOutputsForOp(DifferentialFunction function) {
 if (!ops.containsKey(function.getOwnName()))
 throw new ND4JIllegalStateException("Illegal function instance id found " + function.getOwnName());
 List<String> outputs = ops.get(function.getOwnName()).getOutputsOfOp();
@@ -753,8 +798,8 @@
 * @param function the function reference to get the output variable(s) for
 * @return the output variables for the given function
 */
-public SDVariable[] getOutputVariablesForFunction(DifferentialFunction function) {
-val inputs = getOutputsForFunction(function);
+public SDVariable[] getOutputVariablesForOp(DifferentialFunction function) {
+val inputs = getOutputsForOp(function);
 if (inputs == null) {
 throw new ND4JIllegalStateException("No inputs found for function " + function);
 }
@@ -774,8 +819,8 @@
 * @param function the function reference to get the input variable(s) for
 * @return the input variables for the given function
 */
-public SDVariable[] getInputVariablesForFunction(DifferentialFunction function) {
-val inputs = getInputsForFunction(function);
+public SDVariable[] getInputVariablesForOp(DifferentialFunction function) {
+val inputs = getInputsForOp(function);
 if (inputs == null) {
 throw new ND4JIllegalStateException("No inputs found for function " + function);
 }
@@ -792,6 +837,10 @@
 }
 
 
+/**
+* Set the stored {@link INDArray} for a variable. Only works if the variable is of type
+* {@link VariableType#CONSTANT}, {@link VariableType#PLACEHOLDER}, or {@link VariableType#VARIABLE}.
+*/
 public void setArrayForVariable(@NonNull String varName, @NonNull INDArray arr) {
 Preconditions.checkState(variables.containsKey(varName), "No variable with name \"%s\" exists", varName);
 
@@ -830,6 +879,9 @@
 return variableNameToShape.get(varName);
 }
 
+/**
+* See {@link #getShapeForVarName(String)}, but returns the shape descriptor.
+*/
 public LongShapeDescriptor getShapeDescriptorForVarName(String varName) {
 if (getVariable(varName).getArr() != null) {
 return getVariable(varName).getArr().shapeDescriptor();
@@ -861,6 +913,9 @@
 }
 
 
+/**
+* Sets the shape descriptor for a variable.
+*/
 public void putShapeForVarName(String varName, LongShapeDescriptor shape) {
 val v = getVariable(varName);
 putShapeForVarName(varName, shape.getShape());
@@ -1559,19 +1614,6 @@
 
 }
 
-/**
-* Get the differential function (if any) that this variable is the output for
-*
-* @param variableName Name of the variable
-* @return The differential function that this variable is an output of, or null if it is not the output of a function
-*/
-public DifferentialFunction getVariableOutputFunction(String variableName) {
-Preconditions.checkState(variables.containsKey(variableName), "No variable with name \"%s\" found in graph", variableName);
-if (variables.get(variableName).getOutputOfOp() == null)
-return null;
-return ops.get(variables.get(variableName).getOutputOfOp()).getOp();
-}
-
 
 /**
 * Returns true if this function already has defined arguments
@@ -1628,7 +1670,7 @@
 *
 * @return Array of differential functions
 */
-public DifferentialFunction[] functions() {
+public DifferentialFunction[] ops() {
 List<DifferentialFunction> out = new ArrayList<>(ops.size());
 for (SameDiffOp op : ops.values()) {
 out.add(op.getOp());
@@ -3143,10 +3185,18 @@
 placeholders, batch, requiredActivations, activeListeners, at);
 }
 
+/**
+* See {@link #one(String, DataType, int...)}.
+* Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}).
+*/
 public SDVariable one(String name, int... shape) {
 return one(name, Nd4j.defaultFloatingPointType(), shape);
 }
 
+/**
+* See {@link #one(String, DataType, long...)}.
+* Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}).
+*/
 public SDVariable one(String name, long... shape) {
 return one(name, Nd4j.defaultFloatingPointType(), shape);
 }
@@ -3174,11 +3224,18 @@
 return var(name, new ConstantInitScheme('f', 1.0), dataType, shape);
 }
 
+/**
+* See {@link #zero(String, DataType, long...)}.
+* Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}).
+*/
 public SDVariable zero(String name, long... shape) {
 return zero(name, Nd4j.defaultFloatingPointType(), shape);
 }
 
+/**
+* See {@link #zero(String, DataType, int...)}.
+* Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}).
+*/
 public SDVariable zero(String name, int... shape) {
 return zero(name, Nd4j.defaultFloatingPointType(), shape);
 }
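Note: the convenience overloads documented above default the data type to Nd4j's default floating-point type (float unless reconfigured). A short sketch pairing a defaulted call with an explicit-dtype one:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;

    public class OnesAndZerosSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();

            // Uses Nd4j.defaultFloatingPointType():
            SDVariable ones = sd.one("ones", 2, 3);

            // Explicit dtype overloads remain available:
            SDVariable zeros = sd.zero("zeros", DataType.DOUBLE, 2, 3);

            System.out.println(ones.dataType() + " / " + zeros.dataType());
        }
    }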
@@ -3293,6 +3350,18 @@
 }
 
 //TODO only allowing null datatype for TF import (it's fixed in a later step) - don't want this in the public API!
 
+/**
+* Variable initialization with a specified {@link WeightInitScheme}
+* This method creates VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details.
+*
+* @param name the name of the variable
+* @param variableType the SameDiff variable type of the variable (e.g. CONSTANT, PLACEHOLDER, etc.)
+* @param weightInitScheme the weight initialization scheme
+* @param dataType the data type of the variable (float, int, etc)
+* @param shape the shape of the array to be created
+* @return the created variable
+*/
 public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme,
 org.nd4j.linalg.api.buffer.DataType dataType, long... shape) {
 
@@ -3932,7 +4001,7 @@
 * @param varName the variable name to remove
 * @param function the function to remove the argument from
 */
-public void removeArgFromFunction(String varName, DifferentialFunction function) {
+public void removeArgFromOp(String varName, DifferentialFunction function) {
 val args = function.args();
 
 for (int i = 0; i < args.length; i++) {
@@ -4324,7 +4393,7 @@
 }
 
 //Update the internal state: outgoing variables for function
-if (getOutputsForFunction(function) == null)
+if (getOutputsForOp(function) == null)
 addOutgoingFor(ret, function);
 
 return ret;
@@ -4357,7 +4426,7 @@
 
 
 //Update the internal state: outgoing variables for function
-if (getOutputsForFunction(function) == null)
+if (getOutputsForOp(function) == null)
 addOutgoingFor(ret, function);
 
 return ret;
@@ -4428,7 +4497,9 @@
 .build();
 }
 
+/**
+* Create a new TensorArray.
+*/
 public TensorArray tensorArray(DataType dataType) {
 TensorArray ta = new TensorArray(this, dataType);
 SDVariable[] outVars = ta.outputVariables();
@@ -4439,7 +4510,6 @@
 * @param functionName
 * @param with
 */
-
 public SDVariable invokeFunctionOn(String functionName, SameDiff with) {
 SameDiff instance = sameDiffFunctionInstances.get(functionName);
 SDVariable ret = instance.invokeGraphOn(with);
@@ -5746,6 +5816,13 @@
 return bufferBuilder.dataBuffer();
 }
 
+/**
+* See {@link #asFlatGraph(long, ExecutorConfiguration, boolean)}.
+*
+* Uses the default {@link ExecutorConfiguration} with output mode as
+* {@link OutputMode#VARIABLE_SPACE}, execution mode as {@link ExecutionMode#SEQUENTIAL},
+* with profiling disabled and gather timings enabled.
+*/
 public FlatGraph asFlatGraph(boolean includeUpdaterState) {
 return FlatGraph.getRootAsFlatGraph(this.asFlatBuffers(includeUpdaterState));
 }
@@ -5765,6 +5842,10 @@
 * This method exports the current SameDiff instance into FlatBuffers format, returning the array ops and
 * all arrays as a ByteBuffer containing the FlatBuffers format data
 *
+* Uses the default {@link ExecutorConfiguration} with output mode as
+* {@link OutputMode#VARIABLE_SPACE}, execution mode as {@link ExecutionMode#SEQUENTIAL},
+* with profiling disabled and gather timings enabled.
+*
 * @param includeUpdaterState If true: include the updater state (state for updaters such as Adam, Nesterov, AdaGrad etc)
 * @return a ByteBuffer holding the exported FlatBuffers representation of the graph
 */
@ -5870,7 +5951,11 @@ public class SameDiff extends SDBaseOps {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method converts SameDiff instance to FlatBuffers and saves it to file which can be restored later<br>
|
* This method converts SameDiff instance to FlatBuffers and saves it to file which can be restored later<br>
|
||||||
* This includes the updater state, if applicable
|
* This includes the updater state, if applicable.
|
||||||
|
*
|
||||||
|
* Uses the default {@link ExecutorConfiguration} with output mode as
|
||||||
|
* {@link OutputMode#VARIABLE_SPACE}, execution mode as {@link ExecutionMode#SEQUENTIAL},
|
||||||
|
* with profiling disabled and gather timings enabled.
|
||||||
*
|
*
|
||||||
* @param file File to save the FlatBuffers serialized graph (including arrays) to
|
* @param file File to save the FlatBuffers serialized graph (including arrays) to
|
||||||
*/
|
*/
|
||||||
|
@@ -5878,6 +5963,13 @@ public class SameDiff extends SDBaseOps {
         asFlatFile(file, true);
     }

+    /**
+     * See {@link #asFlatFile(File, ExecutorConfiguration, boolean)}.
+     *
+     * Uses the default {@link ExecutorConfiguration} with output mode as
+     * {@link OutputMode#VARIABLE_SPACE}, execution mode as {@link ExecutionMode#SEQUENTIAL},
+     * with profiling disabled and gather timings enabled.
+     */
     public void asFlatFile(@NonNull File file, boolean withUpdaterState) throws IOException {
         val fb = asFlatBuffers(withUpdaterState);
         val offset = fb.position();
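With the new overload, persisting a graph is a one-liner. A minimal sketch, assuming an already-built SameDiff instance (the file name here is illustrative):

    import java.io.File;
    import java.io.IOException;

    import org.nd4j.autodiff.samediff.SameDiff;

    public class FlatFileSaveSketch {
        public static void main(String[] args) throws IOException {
            SameDiff sd = SameDiff.create();
            // ... build the graph here ...

            // Saves the graph structure, arrays and (because the flag is true) the
            // updater state, using the default ExecutorConfiguration documented above
            sd.asFlatFile(new File("graph.fb"), true);
        }
    }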
@@ -5943,6 +6035,8 @@ public class SameDiff extends SDBaseOps {
      * instance from a byte buffers
      * instance.
      *
+     * See {@link #fromFlatBuffers(ByteBuffer, boolean)}. Loads updater state (loadUpdaterState is true).
+     *
      * @param bbIn the input byte buffer
      * @return the created samediff instance
      * @throws IOException
@@ -5951,6 +6045,16 @@ public class SameDiff extends SDBaseOps {
         return fromFlatBuffers(bbIn, true);
     }

+    /**
+     * Create a {@link SameDiff} instance from a byte buffer.
+     *
+     * @param bbIn             the input byte buffer
+     * @param loadUpdaterState If true, load the updater state (Adam etc state). For training, use true. For inference, use false
+     * @return the created SameDiff instance
+     * @throws IOException
+     */
     public static SameDiff fromFlatBuffers(ByteBuffer bbIn, boolean loadUpdaterState) throws IOException {

         FlatGraph fg = FlatGraph.getRootAsFlatGraph(bbIn);
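The ByteBuffer methods round-trip with each other. A minimal sketch, assuming an already-built graph; pass loadUpdaterState=false when the restored instance is only used for inference:

    import java.io.IOException;
    import java.nio.ByteBuffer;

    import org.nd4j.autodiff.samediff.SameDiff;

    public class FlatBuffersRoundTripSketch {
        public static void main(String[] args) throws IOException {
            SameDiff sd = SameDiff.create();
            // ... build the graph here ...

            // Export including updater state...
            ByteBuffer bb = sd.asFlatBuffers(true);
            // ...and restore it from the same buffer
            SameDiff restored = SameDiff.fromFlatBuffers(bb, true);
        }
    }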
@@ -6287,7 +6391,7 @@ public class SameDiff extends SDBaseOps {
     public String summary() {

         Map<String, SDVariable> varMap = variableMap();
-        DifferentialFunction[] functions = functions();
+        DifferentialFunction[] functions = ops();

         int countVarsWithArrays = 0;
@@ -6324,7 +6428,7 @@ public class SameDiff extends SDBaseOps {
             if (outputOf == null) {
                 outputOf = "<none>";
             } else {
-                DifferentialFunction d = getFunctionById(outputOf);
+                DifferentialFunction d = getOpById(outputOf);
                 outputOf = d.getOwnName() + "(" + d.opName() + ")";
             }
             outputOfFn.put(s, outputOf);
@@ -6412,7 +6516,7 @@ public class SameDiff extends SDBaseOps {
         for (Map.Entry<String, SameDiff> e : sameDiffFunctionInstances.entrySet()) {
             SameDiff sd = e.getValue();
             int vars = sd.variableMap().size();
-            int fns = (sd.functions() == null ? 0 : sd.functions().length);
+            int fns = (sd.ops() == null ? 0 : sd.ops().length);
             int defFns = sd.definedFunctionNames().size();

             sb.append(String.format(format, e.getKey(), String.valueOf(vars), String.valueOf(fns), String.valueOf(defFns))).append("\n");
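The function-to-op renames in this PR are mechanical, so callers migrate one-to-one. A sketch of the mapping, with method names taken from the hunks in this diff:

    import org.nd4j.autodiff.functions.DifferentialFunction;
    import org.nd4j.autodiff.samediff.SameDiff;

    public class RenameMigrationSketch {
        // One-to-one replacements, as they appear throughout this diff:
        //   functions()                        -> ops()
        //   getFunctionById(name)              -> getOpById(name)
        //   getVariableOutputFunction(varName) -> getVariableOutputOp(varName)
        //   getInputsForFunction(fn)           -> getInputsForOp(fn)
        //   getOutputsForFunction(fn)          -> getOutputsForOp(fn)
        //   removeArgFromFunction(in, fn)      -> removeArgFromOp(in, fn)
        //   putFunctionForId(id, fn)           -> putOpForId(id, fn)
        static DifferentialFunction lookup(SameDiff sd, String opName) {
            return sd.getOpById(opName);  // was: sd.getFunctionById(opName)
        }
    }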
@@ -6422,11 +6526,16 @@ public class SameDiff extends SDBaseOps {
         return sb.toString();
     }

+    /**
+     * Calculate data types for the variables in the graph
+     */
     public Map<String, org.nd4j.linalg.api.buffer.DataType> calculateOutputDataTypes() {
         return calculateOutputDataTypes(false);
     }

+    /**
+     * Calculate data types for the variables in the graph
+     */
     public Map<String, org.nd4j.linalg.api.buffer.DataType> calculateOutputDataTypes(boolean dynamicUpdate) {
         List<String> allVars = new ArrayList<>(variables.keySet());
         DataTypesSession session = new DataTypesSession(this, dynamicUpdate);
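A minimal usage sketch for the method above; the returned map is keyed by variable name (the helper name here is illustrative):

    import java.util.Map;

    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;

    public class OutputDataTypesSketch {
        static void printDataTypes(SameDiff sd) {
            // Static inference over the graph; pass true for the dynamic-update variant
            Map<String, DataType> dtypes = sd.calculateOutputDataTypes();
            dtypes.forEach((name, dt) -> System.out.println(name + " -> " + dt));
        }
    }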
@@ -325,7 +325,7 @@ public abstract class AbstractSession<T, O> {
         }

-        } else if (sameDiff.getVariableOutputFunction(varToExec.getVariable()) != null) {
+        } else if (sameDiff.getVariableOutputOp(varToExec.getVariable()) != null) {
             //Variable is the output of an op -> execute op
             String opName = sameDiff.getVariables().get(varToExec.getVariable()).getOutputOfOp();

@@ -336,7 +336,7 @@ public abstract class AbstractSession<T, O> {

         //Post execution: work out what is now available for exec
-        String[] opOutputVarNames = sameDiff.getFunctionById(opName).outputVariablesNames();
+        String[] opOutputVarNames = sameDiff.getOpById(opName).outputVariablesNames();

         Preconditions.checkState(opOutputValues.length == opOutputVarNames.length, "Unexpected number of outputs from executed op %s:" +
                 " got %s outputs when %s outputs were expected (%s)", parameterizedOp.getClass().getSimpleName(), opOutputValues.length,
@@ -423,10 +423,10 @@ public abstract class AbstractSession<T, O> {
         //Note subgraph initially should include placeholders and constants
         while (!processingQueue.isEmpty()) {
             String varName = processingQueue.remove();
-            String opName = (sameDiff.getVariableOutputFunction(varName) == null ? null : sameDiff.getVariableOutputFunction(varName).getOwnName());
+            String opName = (sameDiff.getVariableOutputOp(varName) == null ? null : sameDiff.getVariableOutputOp(varName).getOwnName());

             if (!subgraph.contains(varName)) {
-                String[] opInputs = opName == null ? null : sameDiff.getInputsForFunction(sameDiff.getFunctionById(opName));
+                String[] opInputs = opName == null ? null : sameDiff.getInputsForOp(sameDiff.getOpById(opName));
                 List<String> controlDeps = sameDiff.getVariables().get(varName).getControlDeps();
                 int numInputs = (opInputs == null ? 0 : opInputs.length);
                 if (controlDeps != null) {

@@ -457,7 +457,7 @@ public abstract class AbstractSession<T, O> {

         if (opName != null) {
             //To execute op - and hence get this variable: need inputs to that op
-            String[] inputs = sameDiff.getInputsForFunction(sameDiff.getFunctionById(opName));
+            String[] inputs = sameDiff.getInputsForOp(sameDiff.getOpById(opName));
             for (String s2 : inputs) {
                 if (!subgraph.contains(s2)) {
                     processingQueue.add(s2);
@@ -501,7 +501,7 @@ public abstract class AbstractSession<T, O> {
         if (inputForOps != null) {
             for (String opName : inputForOps) {

-                DifferentialFunction fn = sameDiff.getFunctionById(opName);
+                DifferentialFunction fn = sameDiff.getOpById(opName);
                 if (fn instanceof Merge) {
                     //Merge op: available for execution when *any* of its inputs are available. But only mark it for exec once...
                     List<String> opOutputs = sameDiff.getOps().get(opName).getOutputsOfOp();

@@ -888,7 +888,7 @@ public abstract class AbstractSession<T, O> {
         //Mark that outVar needs this specific executedVar (i.e., specific frame/iteration)
         //However, in the case of enter nodes, they are available for ALL iterations (used in loop conditions, for example)
         Variable v = sameDiff.getVariables().get(inputVar.getVariable());
-        boolean isEnter = sameDiff.getVariableOutputFunction(v.getVariable().getVarName()) instanceof Enter;
+        boolean isEnter = sameDiff.getVariableOutputOp(v.getVariable().getVarName()) instanceof Enter;

         if(isEnter){
             VarId iter0 = forVariable;
@@ -59,7 +59,7 @@ public class DataTypesSession extends AbstractSession<DataType, DataTypesSession

     @Override
     public DataTypeCalc getAndParameterizeOp(String opName, FrameIter frameIter, Set<VarId> inputs, Set<VarId> allIterInputs, Set<String> constAndPhInputs, Map<String, DataType> placeholderValues) {
-        DifferentialFunction df = sameDiff.getFunctionById(opName);
+        DifferentialFunction df = sameDiff.getOpById(opName);
         List<DataType> inputDataTypes = new ArrayList<>();
         for(SDVariable v : df.args()){
             DataType dt = v.dataType();
@@ -16,7 +16,6 @@

 package org.nd4j.autodiff.samediff.internal;

-import com.google.common.collect.ImmutableMap;
 import lombok.NonNull;
 import lombok.extern.slf4j.Slf4j;
 import org.nd4j.autodiff.functions.DifferentialFunction;
@@ -121,12 +120,12 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
         if(listeners != null && listeners.size() > 0){
             SameDiffOp sdOp = sameDiff.getOps().get(op.getOwnName());

-            ImmutableMap.Builder<String, INDArray> namedOutsBuilder = ImmutableMap.builder();
+            Map<String, INDArray> namedOutsBuilder = new HashMap<>();

             for(int i = 0 ; i < out.length ; i++)
                 namedOutsBuilder.put(sdOp.outputsOfOp.get(i), out[i]);

-            Map<String, INDArray> namedOuts = namedOutsBuilder.build();
+            Map<String, INDArray> namedOuts = Collections.unmodifiableMap(namedOutsBuilder);

             for(Listener l : listeners){
                 if(l.isActive(at.operation())) {
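The guava-to-JDK swap above follows one pattern everywhere it appears in this PR. A minimal sketch with illustrative names; note that unmodifiableMap wraps rather than copies, so the builder map must not escape or be mutated afterwards:

    import java.util.Collections;
    import java.util.HashMap;
    import java.util.Map;

    public class UnmodifiableMapSketch {
        static Map<String, Integer> buildReadOnly() {
            // Plain mutable map in place of ImmutableMap.Builder
            Map<String, Integer> builder = new HashMap<>();
            builder.put("a", 1);
            builder.put("b", 2);
            // Read-only view in place of ImmutableMap.Builder#build()
            return Collections.unmodifiableMap(builder);
        }
    }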
@@ -223,7 +222,7 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
         //Merge available for forward pass when any of its inputs are available. When multiple are available, behaviour
         // is undefined
         Merge m = (Merge) op;
-        String[] in = sameDiff.getInputsForFunction(op);
+        String[] in = sameDiff.getInputsForOp(op);
         for (String s : in) {
             VarId vid = newVarId(s, outputFrameIter);
             if (nodeOutputs.containsKey(vid)) {
@@ -275,10 +274,10 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct

         Preconditions.checkState(v != null, "Could not find input %s", inTensorArray.getVarName());

-        while(sameDiff.getVariableOutputFunction(inTensorArray.getVarName()) instanceof Enter){
+        while(sameDiff.getVariableOutputOp(inTensorArray.getVarName()) instanceof Enter){
             //Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayRead
             //TODO also TensorArrayWrite, scatter, etc??
-            inTensorArray = sameDiff.getVariableOutputFunction(inTensorArray.getVarName()).arg();
+            inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.getVarName()).arg();
             v = newVarId(inTensorArray.getVarName(), v.getParentFrame());
         }
@@ -300,10 +299,10 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct

         Preconditions.checkState(tArr != null, "Could not find input %s", inTensorArray.getVarName());

-        while(sameDiff.getVariableOutputFunction(inTensorArray.getVarName()) instanceof Enter){
+        while(sameDiff.getVariableOutputOp(inTensorArray.getVarName()) instanceof Enter){
             //Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayWrite
             //TODO also TensorArrayScatter, etc??
-            inTensorArray = sameDiff.getVariableOutputFunction(inTensorArray.getVarName()).arg();
+            inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.getVarName()).arg();
             tArr = newVarId(inTensorArray.getVarName(), tArr.getParentFrame());
         }
@@ -405,7 +404,7 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
         //Input 2: The values to scatter

         SDVariable inTensorArray = op.arg(0);   //Dummy variable representing the tensor array
-        TensorArray ta = (TensorArray) sameDiff.getVariableOutputFunction(inTensorArray.getVarName());
+        TensorArray ta = (TensorArray) sameDiff.getVariableOutputOp(inTensorArray.getVarName());
         VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false));
         if(tArr == null && allIterInputs != null){
             tArr = lookup(inTensorArray.getVarName(), allIterInputs, false);
@@ -526,7 +525,7 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
     public DifferentialFunction getAndParameterizeOp(String opName, FrameIter frameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
                                                      Set<String> constAndPhInputs, Map<String,INDArray> placeholderValues) {

-        DifferentialFunction df = sameDiff.getFunctionById(opName);
+        DifferentialFunction df = sameDiff.getOpById(opName);

         //TODO We should clone these ops - probably - as we don't want them shared between threads/sessions!
         //But let's only clone them *once* and cache in inference session - not on every exec
@@ -18,7 +18,6 @@ package org.nd4j.autodiff.samediff.ops;

 import com.google.common.collect.Sets;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
 import lombok.NonNull;

@@ -27,7 +26,6 @@ import org.nd4j.autodiff.samediff.ArgumentInterceptor;
 import org.nd4j.autodiff.samediff.NameScope;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
 import org.nd4j.autodiff.samediff.SameDiffLambda;
 import org.nd4j.autodiff.samediff.SameDiffNoArgSingleLambda;
 import org.nd4j.autodiff.samediff.SameDiffSingleLambda;
@@ -3377,7 +3375,7 @@ public abstract class SDBaseOps {

         for(SameDiffOp op : sd().getOpsInScope(ifScope)) {
             for(String in : op.getInputsToOp()){
-                sd().removeArgFromFunction(in, op.getOp());
+                sd().removeArgFromOp(in, op.getOp());
             }
             sd().getOps().remove(op.getName());
         }
@@ -385,6 +385,29 @@ public class SDCNN extends SDOps {
         return updateVariableNameAndReference(ret, name);
     }

+    /**
+     * 3D CNN deconvolution operation with or without optional bias
+     *
+     * @param input   Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
+     * @param weights Weights array - shape [kD, kH, kW, oC, iC]
+     * @param bias    Bias array - optional, may be null. If non-null, must have shape [outputChannels]
+     * @param config  Configuration
+     */
+    public SDVariable deconv3d(SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) {
+        return deconv3d(null, input, weights, bias, config);
+    }
+
+    /**
+     * 3D CNN deconvolution operation with no bias
+     *
+     * @param input   Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
+     * @param weights Weights array - shape [kD, kH, kW, oC, iC]
+     * @param config  Configuration
+     */
+    public SDVariable deconv3d(SDVariable input, SDVariable weights, DeConv3DConfig config) {
+        return deconv3d(input, weights, null, config);
+    }
+
     /**
      * Convolution 2d layer batch to space operation on 4d input.<br>
      * Reduces input channels dimension by rearranging data into a larger spatial dimensions<br>
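A usage sketch for the new no-bias overload. The DeConv3DConfig builder field names (kD/sD and the NCDHW constant) are assumptions based on the other 3D convolution configs in this module, not confirmed by this diff:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;

    public class Deconv3dSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            // NCDHW input: [minibatch, channels, depth, height, width]
            SDVariable in = sd.placeHolder("in", DataType.FLOAT, 1, 4, 8, 8, 8);
            // Weights: [kD, kH, kW, oC, iC], per the javadoc above
            SDVariable w = sd.var("w", DataType.FLOAT, 2, 2, 2, 3, 4);

            DeConv3DConfig conf = DeConv3DConfig.builder()
                    .kD(2).kH(2).kW(2)                 // kernel size (assumed builder names)
                    .sD(1).sH(1).sW(1)                 // strides
                    .dataFormat(DeConv3DConfig.NCDHW)  // assumed constant name
                    .build();

            // New overload added in this diff: bias-free deconv3d
            SDVariable out = sd.cnn().deconv3d(in, w, conf);
        }
    }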
@@ -199,7 +199,7 @@ public class LegacyOpMapper
             case 25:
                 return Or.class;
             case 26:
-                return OldAtan2Op.class;
+                throw new UnsupportedOperationException("OldATan2 (op number " + opNum + ") is no longer supported.");
             case 27:
                 return LogicalOr.class;
             case 28:

@@ -243,7 +243,7 @@ public class LegacyOpMapper
             case 18:
                 return Floor.class;
             case 20:
-                return OldReverse.class;
+                throw new UnsupportedOperationException("OldReverse (op number " + opNum + ") is no longer supported.");
             default:
                 throw new UnsupportedOperationException("No known transform same op for op number: " + opNum);
         }
@@ -581,19 +581,19 @@ public class LegacyOpMapper {
     public static Class<?> pairwiseOpClass(int opNum){
         switch (opNum){
             case 0:
-                return OldAddOp.class;
+                throw new UnsupportedOperationException("OldAddOp (op number " + opNum + ") is no longer supported.");
             case 1:
                 return CopyOp.class;
             case 2:
-                return OldDivOp.class;
+                throw new UnsupportedOperationException("OldDivOp (op number " + opNum + ") is no longer supported.");
             case 3:
-                return OldEqualTo.class;
+                throw new UnsupportedOperationException("OldEqualTo (op number " + opNum + ") is no longer supported.");
             case 4:
-                return OldGreaterThan.class;
+                throw new UnsupportedOperationException("OldGreaterThan (op number " + opNum + ") is no longer supported.");
             case 5:
-                return OldLessThan.class;
+                throw new UnsupportedOperationException("OldLessThan (op number " + opNum + ") is no longer supported.");
             case 6:
-                return OldMulOp.class;
+                throw new UnsupportedOperationException("OldMulOp (op number " + opNum + ") is no longer supported.");
             case 7:
                 return Pow.class;
             case 8:
@@ -603,15 +603,15 @@ public class LegacyOpMapper {
             case 10:
                 return Eps.class;
             case 11:
-                return OldGreaterThanOrEqual.class;
+                throw new UnsupportedOperationException("OldGreaterThanOrEqual (op number " + opNum + ") is no longer supported.");
             case 12:
-                return OldLessThanOrEqual.class;
+                throw new UnsupportedOperationException("OldLessThanOrEqual (op number " + opNum + ") is no longer supported.");
             case 13:
-                return OldMax.class;
+                throw new UnsupportedOperationException("OldMax (op number " + opNum + ") is no longer supported.");
             case 14:
-                return OldMin.class;
+                throw new UnsupportedOperationException("OldMin (op number " + opNum + ") is no longer supported.");
             case 15:
-                return OldNotEqualTo.class;
+                throw new UnsupportedOperationException("OldNotEqualTo (op number " + opNum + ") is no longer supported.");
             case 16:
                 return Set.class;
             case 17:
@@ -631,11 +631,11 @@ public class LegacyOpMapper {
             case 59:
                 return RemainderOp.class;
             case 60:
-                return OldFModOp.class;
+                throw new UnsupportedOperationException("OldFModOp (op number " + opNum + ") is no longer supported.");
             case 69:
-                return OldAtan2Op.class;
+                throw new UnsupportedOperationException("OldATan2 (op number " + opNum + ") is no longer supported.");
             case 20:
-                return OldFloorDivOp.class;
+                throw new UnsupportedOperationException("OldFloorDivOp (op number " + opNum + ") is no longer supported.");
             case 26:
                 return RelativeError.class;
             case 27:
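Because these cases now throw instead of returning a class, any caller that probes legacy op numbers needs to treat UnsupportedOperationException as "removed". A hedged sketch; the LegacyOpMapper import path is an assumption:

    // Import path assumed; take it from wherever LegacyOpMapper lives in this module
    import org.nd4j.autodiff.samediff.serde.LegacyOpMapper;

    public class LegacyOpProbeSketch {
        /** Returns the mapped op class, or null when the legacy op has been removed. */
        static Class<?> tryMapPairwise(int opNum) {
            try {
                return LegacyOpMapper.pairwiseOpClass(opNum);
            } catch (UnsupportedOperationException e) {
                return null;  // a removed Old* op, or an op number with no known mapping
            }
        }
    }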
@@ -78,7 +78,7 @@ public class GraphTransformUtil {
         if (oldInputsForOps != null) {
             List<String> newInputsForOps = new ArrayList<>();
             for (String s : oldInputsForOps) {
-                DifferentialFunction df = sd.getFunctionById(s);
+                DifferentialFunction df = sd.getOpById(s);
                 if (!allSubGraphFns.contains(df)) {
                     newInputsForOps.add(s);
                 }
             }

@@ -141,7 +141,7 @@ public class GraphTransformUtil {
         // (1) variable is (was) input to op that has been removed - just remove from list
         // (2) variable is now connected directly as an output: (A->B->C) becomes (A->C)
         // For the latter case, this
-        DifferentialFunction df = sd.getFunctionById(opName);
+        DifferentialFunction df = sd.getOpById(opName);
         if (allSubGraphFns.contains(df)) {
             newInputsForOp.remove(opName);
         }
@@ -178,7 +178,7 @@ public class GraphTransformUtil {
      */
     public static List<SubGraph> getSubgraphsMatching(SameDiff sd, SubGraphPredicate p) {
         List<SubGraph> out = new ArrayList<>();
-        for (DifferentialFunction df : sd.functions()) {
+        for (DifferentialFunction df : sd.ops()) {
             if (p.matches(sd, df)) {
                 SubGraph sg = p.getSubGraph(sd, df);
                 out.add(sg);
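For reference, the matching entry point after the rename. The predicate factory names (withRoot, classEquals) are assumptions based on this transform package's API, not shown in this diff:

    import java.util.List;

    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.autodiff.samediff.transform.GraphTransformUtil;
    import org.nd4j.autodiff.samediff.transform.OpPredicate;
    import org.nd4j.autodiff.samediff.transform.SubGraph;
    import org.nd4j.autodiff.samediff.transform.SubGraphPredicate;
    import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo;

    public class SubgraphMatchSketch {
        static List<SubGraph> findEqualToSubgraphs(SameDiff sd) {
            // Match every EqualTo op as a single-op subgraph; iteration inside
            // getSubgraphsMatching now runs over sd.ops() rather than sd.functions()
            SubGraphPredicate p = SubGraphPredicate.withRoot(OpPredicate.classEquals(EqualTo.class));
            return GraphTransformUtil.getSubgraphsMatching(sd, p);
        }
    }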
@@ -20,7 +20,6 @@ import lombok.AllArgsConstructor;
 import lombok.Builder;
 import lombok.Data;
 import lombok.NoArgsConstructor;
-import org.apache.commons.lang3.builder.Diff;
 import org.nd4j.autodiff.functions.DifferentialFunction;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;

@@ -68,7 +67,7 @@ public class SubGraph {
         boolean allInSubgraph = true;
         if(inputsFor != null){
             for(String opOwnName : inputsFor) {
-                if (!inSubgraph(sameDiff.getFunctionById(opOwnName))){
+                if (!inSubgraph(sameDiff.getOpById(opOwnName))){
                     allInSubgraph = false;
                     break;
                 }
@@ -77,7 +77,7 @@ public class SubGraphPredicate extends OpPredicate {
         }

         SDVariable in = inputs[inNum];
-        DifferentialFunction df = sameDiff.getVariableOutputFunction(in.getVarName());
+        DifferentialFunction df = sameDiff.getVariableOutputOp(in.getVarName());
         if (df == null || !e.getValue().matches(sameDiff, df)) {
             return false;
         }

@@ -103,7 +103,7 @@ public class SubGraphPredicate extends OpPredicate {
         for(Map.Entry<Integer,OpPredicate> entry : opInputSubgraphPredicates.entrySet()){
             OpPredicate p2 = entry.getValue();
             SDVariable arg = rootFn.arg(entry.getKey());
-            DifferentialFunction df = sd.getVariableOutputFunction(arg.getVarName());
+            DifferentialFunction df = sd.getVariableOutputOp(arg.getVarName());
             if(df != null){
                 childNodes.add(df);
@@ -28,7 +28,6 @@ import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.autodiff.samediff.internal.Variable;
 import org.nd4j.autodiff.validation.listeners.NonInplaceValidationListener;
 import org.nd4j.base.Preconditions;
-import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.iter.NdIndexIterator;
 import org.nd4j.linalg.api.ndarray.INDArray;

@@ -106,7 +105,7 @@ public class GradCheckUtil {
         }

         Set<String> fnOutputs = new HashSet<>();
-        for(DifferentialFunction f : sd.functions()){
+        for(DifferentialFunction f : sd.ops()){
             for(SDVariable s : f.outputVariables()){
                 fnOutputs.add(s.getVarName());
             }
@@ -593,7 +592,7 @@ public class GradCheckUtil {
         4. Gradient function: should contain all of the existing functions, and more
         */

-        DifferentialFunction[] dfs = sd.functions();
+        DifferentialFunction[] dfs = sd.ops();
         List<SDVariable> vars = sd.variables();

         Set<String> varSetStr = new HashSet<>();

@@ -661,7 +660,7 @@ public class GradCheckUtil {

         //Check that all original functions are present in the gradient function
         for(DifferentialFunction dfOrig : dfs){
-            Preconditions.checkNotNull(gradFn.getFunctionById(dfOrig.getOwnName()), "DifferentialFunction " + dfOrig.getOwnName()
+            Preconditions.checkNotNull(gradFn.getOpById(dfOrig.getOwnName()), "DifferentialFunction " + dfOrig.getOwnName()
                     + " from original SameDiff instance not present in grad fn");
         }
     }
@@ -79,7 +79,6 @@ import org.nd4j.linalg.api.ops.impl.transforms.gradient.*;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Not;
-import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax;
 import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.*;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.GELUDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative;

@@ -94,7 +93,6 @@ import org.nd4j.linalg.api.ops.random.impl.Linspace;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.function.Function;
-import org.nd4j.linalg.indexing.NDArrayIndex;
 import org.nd4j.linalg.indexing.conditions.Conditions;
 import org.nd4j.linalg.primitives.Pair;
 import org.tensorflow.framework.OpDef;
@@ -464,7 +462,7 @@ public class OpValidation {
         //i.e., don't double count if a SameDiff instance has multiple copies of the same op type

         //Collect coverage information for backprop:
-        DifferentialFunction[] functions = sd.functions();
+        DifferentialFunction[] functions = sd.ops();
         Set<Class> backpropSeen = new HashSet<>();
         for (DifferentialFunction df : functions) {
             backpropSeen.add(df.getClass());

@@ -481,7 +479,7 @@ public class OpValidation {
         if (testCase.fwdTestFns() != null) {
             for (String s : testCase.fwdTestFns().keySet()) {
                 //Determine the differential function that this variable is the output of, if any
-                DifferentialFunction df = sd.getVariableOutputFunction(s);
+                DifferentialFunction df = sd.getVariableOutputOp(s);
                 if (df != null) {
                     if (seen == null)
                         seen = new HashSet<>();
@@ -31,7 +31,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.Op;
 import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
 import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
-import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
+import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.INDArrayIndex;
 import org.nd4j.linalg.indexing.NDArrayIndex;

@@ -708,8 +708,8 @@ public class ROC extends BaseEvaluation<ROC> {
             itp = isTruePositive;
             ifp = isFalsePositive;
         } else {
-            isTruePositive = Nd4j.getExecutioner().exec(new OldMulOp(predictedClass1, positiveActualClassColumn, itp));
-            isFalsePositive = Nd4j.getExecutioner().exec(new OldMulOp(predictedClass1, negativeActualClassColumn, ifp));
+            isTruePositive = Nd4j.getExecutioner().exec(new MulOp(predictedClass1, positiveActualClassColumn, itp))[0];
+            isFalsePositive = Nd4j.getExecutioner().exec(new MulOp(predictedClass1, negativeActualClassColumn, ifp))[0];
         }

         //Counts for this batch:
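The trailing [0] above is the visible API difference: legacy-op exec returned the single result INDArray, while the CustomOp overload returns every op output as an array. A minimal sketch with MulOp:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
    import org.nd4j.linalg.factory.Nd4j;

    public class MulOpExecSketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.rand(3, 3);
            INDArray y = Nd4j.rand(3, 3);
            INDArray out = Nd4j.createUninitialized(x.shape());

            // exec(CustomOp) returns all outputs; MulOp has exactly one
            INDArray z = Nd4j.getExecutioner().exec(new MulOp(x, y, out))[0];
            System.out.println(z);
        }
    }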
@@ -68,17 +68,6 @@ public class DifferentialFunctionClassHolder {
         add("sameDiff");
         add("ownName");
     }};
-    private static final Set<String> classesWithConfig = new LinkedHashSet<String>(){{
-        add(AvgPooling2D.class.getName());
-        add(Conv2D.class.getName());
-        add(Conv3D.class.getName());
-        add(LocalResponseNormalization.class.getName());
-        add(MaxPooling2D.class.getName());
-        add(Pooling2D.class.getName());
-        add(Pooling3D.class.getName());
-        add(DepthwiseConv2D.class.getName());
-        add(DeConv2DTF.class.getName());
-    }};
     //When determining fields/properties, where should we terminate the search?
     //We don't want to include every single field from every single superclass
     private static final Set<Class> classesToIgnore = new HashSet<>(Arrays.<Class>asList(
@@ -165,16 +154,37 @@ public class DifferentialFunctionClassHolder {
         Map<String, Field> fieldNames = new LinkedHashMap<>();
         Class<? extends DifferentialFunction> current = df.getClass();
         val fields = new ArrayList<Field>();
+        boolean isFirst = true;

         while (current.getSuperclass() != null && !classesToIgnore.contains(current.getSuperclass())) {
-            if (classesWithConfig.contains(current.getName())) {
-
-                val fieldName = "config";
-
-                val configField = current.getDeclaredField(fieldName);
-                if (configField == null) {
-                    continue;
+            if (df.isConfigProperties() && isFirst) {
+
+                String fieldName = df.configFieldName();
+
+                if(fieldName == null)
+                    fieldName = "config";
+
+                Field configField = null;
+                try{
+                    configField = current.getDeclaredField(fieldName);
+                } catch (NoSuchFieldException e){
+                    Class<?> currentConfig = current.getSuperclass();
+
+                    // find a config field in superclasses
+                    while(currentConfig.getSuperclass() != null){
+                        try {
+                            configField = currentConfig.getDeclaredField(fieldName);
+                            break;
+                        } catch (NoSuchFieldException e2){
+                            currentConfig = currentConfig.getSuperclass();
+                        }
+                    }
                 }

+                if(configField == null)
+                    continue;
+
                 val configFieldClass = configField.getType();

                 for (val field : configFieldClass.getDeclaredFields()) {

@@ -206,6 +216,7 @@ public class DifferentialFunctionClassHolder {

             // do something with current's fields
             current = (Class<? extends DifferentialFunction>) current.getSuperclass();
+            isFirst = false;
         }
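The superclass walk added above is a standard reflection pattern; this standalone sketch shows the same lookup outside of any ND4J types:

    import java.lang.reflect.Field;

    public class FieldLookupSketch {
        /**
         * Find a declared field by name on the given class or the nearest superclass
         * that declares it; returns null when no class in the hierarchy does.
         */
        static Field findField(Class<?> start, String fieldName) {
            for (Class<?> c = start; c != null; c = c.getSuperclass()) {
                try {
                    return c.getDeclaredField(fieldName);
                } catch (NoSuchFieldException e) {
                    // not declared here - keep walking up the hierarchy
                }
            }
            return null;
        }
    }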
@@ -347,14 +347,6 @@ public class ImportClassMapping {
         org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace.class,
         org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet.class,
         org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps.class,
-        org.nd4j.linalg.api.ops.impl.transforms.comparison.OldEqualTo.class,
-        org.nd4j.linalg.api.ops.impl.transforms.comparison.OldGreaterThan.class,
-        org.nd4j.linalg.api.ops.impl.transforms.comparison.OldGreaterThanOrEqual.class,
-        org.nd4j.linalg.api.ops.impl.transforms.comparison.OldLessThan.class,
-        org.nd4j.linalg.api.ops.impl.transforms.comparison.OldLessThanOrEqual.class,
-        org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax.class,
-        org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin.class,
-        org.nd4j.linalg.api.ops.impl.transforms.comparison.OldNotEqualTo.class,
         org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2.class,
         org.nd4j.linalg.api.ops.impl.transforms.custom.Assign.class,
         org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace.class,

@@ -453,15 +445,6 @@ public class ImportClassMapping {
         org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp.class,
         org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp.class,
         org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp.class,
-        org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAddOp.class,
-        org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAtan2Op.class,
-        org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldDivOp.class,
-        org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldFModOp.class,
-        org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldFloorDivOp.class,
-        org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp.class,
-        org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldRDivOp.class,
-        org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldRSubOp.class,
-        org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldSubOp.class,
         org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.PowPairwise.class,
         org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp.class,
         org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp.class,

@@ -493,8 +476,6 @@ public class ImportClassMapping {
         org.nd4j.linalg.api.ops.impl.transforms.same.Max.class,
         org.nd4j.linalg.api.ops.impl.transforms.same.Min.class,
         org.nd4j.linalg.api.ops.impl.transforms.same.Negative.class,
-        org.nd4j.linalg.api.ops.impl.transforms.same.OldIdentity.class,
-        org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse.class,
         org.nd4j.linalg.api.ops.impl.transforms.same.OneMinus.class,
         org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal.class,
         org.nd4j.linalg.api.ops.impl.transforms.same.Round.class,
@@ -361,7 +361,7 @@ public abstract class BaseGraphMapper<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE, TENSOR_T
     }

     protected void initOutputVariables(SameDiff sd, DifferentialFunction df) {
-        String[] outNames = sd.getOutputsForFunction(df);
+        String[] outNames = sd.getOutputsForOp(df);
         SDVariable[] outVars;
         if (outNames == null) {
             outVars = sd.generateOutputVariableForOp(df, df.getOwnName() != null ? df.getOwnName() : df.opName(), true);
@@ -409,7 +409,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
         newInstance.setSameDiff(importState.getSameDiff());

         newInstance.initFromOnnx(tfNode,diff,getAttrMap(tfNode),importState.getGraph());
-        importState.getSameDiff().putFunctionForId(newInstance.getOwnName(),newInstance);
+        importState.getSameDiff().putOpForId(newInstance.getOwnName(),newInstance);
         //ensure we can track node name to function instance later.
         diff.setBaseNameForFunctionInstanceId(tfNode.getName(),newInstance);
         //diff.addVarNameForImport(tfNode.getName());
@@ -16,14 +16,11 @@

 package org.nd4j.imports.graphmapper.tf;

-import com.github.os72.protobuf351.Descriptors;
 import com.github.os72.protobuf351.Message;
 import com.google.common.primitives.Floats;
 import com.google.common.primitives.Ints;
 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
-import org.bytedeco.javacpp.indexer.Bfloat16Indexer;
-import org.bytedeco.javacpp.indexer.HalfIndexer;
 import org.nd4j.autodiff.functions.DifferentialFunction;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;

@@ -41,19 +38,14 @@ import org.nd4j.imports.graphmapper.OpImportFilter;
 import org.nd4j.imports.graphmapper.OpImportOverride;
 import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMapper;
 import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMappers;
-import org.nd4j.linalg.api.buffer.*;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.controlflow.IfImportState;
-import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.util.ArrayUtil;
 import org.tensorflow.framework.*;
 import org.tensorflow.framework.DataType;

 import java.io.*;
 import java.nio.ByteBuffer;
-import java.nio.ByteOrder;
 import java.util.*;

 /**

@@ -661,7 +653,7 @@ public class TFGraphMapper extends BaseGraphMapper<GraphDef,NodeDef,AttrValue,No

         newInstance.initFromTensorFlow(tfNode, diff, getAttrMap(tfNode), importState.getGraph());
         mapProperties(newInstance, tfNode, importState.getGraph(), importState.getSameDiff(), newInstance.mappingsForFunction());
-        importState.getSameDiff().putFunctionForId(newInstance.getOwnName(), newInstance);
+        importState.getSameDiff().putOpForId(newInstance.getOwnName(), newInstance);
         //ensure we can track node name to function instance later.
         diff.setBaseNameForFunctionInstanceId(tfNode.getName(), newInstance);
     } catch (Exception e) {
@@ -61,6 +61,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo;
 import org.nd4j.linalg.api.ops.impl.transforms.same.Negative;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*;
 import org.nd4j.linalg.api.ops.impl.transforms.comparison.*;

@@ -1638,7 +1639,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
     public INDArray lt(INDArray other) {
         validateNumericalArray("less than (lt)", false);
         if (Shape.shapeEquals(this.shape(), other.shape())) {
-            return Nd4j.getExecutioner().exec(new OldLessThan(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())));
+            return Nd4j.getExecutioner().exec(new LessThan(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0];
         } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
             return Nd4j.exec(new LessThan(new INDArray[]{this, other}, new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0];
         } else
@@ -1655,13 +1656,13 @@ public abstract class BaseNDArray implements INDArray, Iterable {
     @Override
     public INDArray neq(INDArray other) {
         Preconditions.checkState(!isEmpty(), "Cannot perform operation neq (not equal) on empty array");
-        return Nd4j.getExecutioner().exec(new OldNotEqualTo(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())));
+        return Nd4j.getExecutioner().exec(new NotEqualTo(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0];
     }

     @Override
     public INDArray eq(INDArray other) {
         if (Shape.shapeEquals(this.shape(), other.shape())) {
-            return Nd4j.getExecutioner().exec(new OldEqualTo(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())));
+            return Nd4j.getExecutioner().exec(new EqualTo(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0];
         } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
             return Nd4j.exec(new EqualTo(new INDArray[]{this, other}, new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0];
         } else

@@ -1672,7 +1673,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
     public INDArray gt(INDArray other) {
         validateNumericalArray("greater than (gt)", false);
         if (Shape.shapeEquals(this.shape(), other.shape())) {
-            return Nd4j.getExecutioner().exec(new OldGreaterThan(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())));
+            return Nd4j.getExecutioner().exec(new GreaterThan(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0];
         } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
             return Nd4j.exec(new GreaterThan(new INDArray[]{this, other}, new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0];
         } else
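Caller-facing behaviour of lt/eq/gt/neq is unchanged; they still return BOOL arrays. A quick sketch:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class ComparisonOpsSketch {
        public static void main(String[] args) {
            INDArray a = Nd4j.createFromArray(1.0, 2.0, 3.0);
            INDArray b = Nd4j.createFromArray(3.0, 2.0, 1.0);

            // All three now route through the CustomOp comparison classes internally
            INDArray lt = a.lt(b);   // [true,  false, false]
            INDArray eq = a.eq(b);   // [false, true,  false]
            INDArray gt = a.gt(b);   // [false, false, true ]
        }
    }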
@@ -5989,7 +5990,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {

             return result;
         } else {
-            OldFModOp op = new OldFModOp(this, denominator, result);
+            FModOp op = new FModOp(this, denominator, result);
             Nd4j.getExecutioner().exec(op);
             return result;
         }

@@ -6011,7 +6012,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
     @Override
     public INDArray fmodi(INDArray denominator) {
         validateNumericalArray("fmodi", false);
-        OldFModOp op = new OldFModOp(this, denominator, this);
+        FModOp op = new FModOp(this, denominator, this);
         Nd4j.getExecutioner().exec(op);
         return this;
     }
@@ -274,7 +274,7 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
     @Override
     public SDVariable[] outputVariables(String baseName) {
         if(zVertexId == null) {
-            val outputNames = sameDiff.getOutputsForFunction(this);
+            val outputNames = sameDiff.getOutputsForOp(this);
             //no need to dynamically create if already exists
             if(outputNames != null) {
                 zVertexId = sameDiff.getVariable(outputNames[0]).getVarName();
@@ -293,13 +293,13 @@ public abstract class BaseOp extends DifferentialFunction implements Op {

                 sameDiff.setArrayForVariable(newVars[0].getVarName(),inputArr);
                 z = inputArr;
-                if(sameDiff.getOutputsForFunction(this) == null)
+                if(sameDiff.getOutputsForOp(this) == null)
                     sameDiff.addOutgoingFor(newVars,this);
                 return newVars;
             }

         SDVariable[] newVars = sameDiff.generateOutputVariableForOp(this, baseName, false);
-        if (sameDiff.getOutputsForFunction(this) == null)
+        if (sameDiff.getOutputsForOp(this) == null)
             sameDiff.addOutgoingFor(newVars, this);
         return newVars;
     }

@@ -25,11 +25,9 @@ import onnx.OnnxProto3;
 import org.nd4j.autodiff.functions.DifferentialFunction;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
-import org.nd4j.linalg.api.shape.Shape;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.util.ArrayUtil;
@@ -208,7 +206,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
     @Override
     public SDVariable[] outputVariables(String baseName) {
         if (this.outputVariables == null) {
-            val outputNames = sameDiff.getOutputsForFunction(this);
+            val outputNames = sameDiff.getOutputsForOp(this);
             //no need to dynamically create if already exists
             if (outputNames != null) {
                 outputVariables = new SDVariable[outputNames.length];
@@ -233,7 +231,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
             }

             outputVariables = newVars;
-            if (sameDiff.getOutputsForFunction(this) == null)
+            if (sameDiff.getOutputsForOp(this) == null)
                 sameDiff.addOutgoingFor(outputVariables, this);
             return newVars;
         }
@@ -524,7 +522,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
             throw new ND4JIllegalStateException("Op [" + opName() + "] failure for [" + this.getOwnName() + "]: Number of inputs is invalid for execution. "
                     + numInputArguments() + " were provided but " + descriptor.getNumInputs() + " are required for execution");
         } else {
-            String[] inputNames = sameDiff.getInputsForFunction(this);
+            String[] inputNames = sameDiff.getInputsForOp(this);
             String[] arrayShapes = new String[inputNames.length];
             for( int i=0; i<inputNames.length; i++ ){
                 INDArray arr = sameDiff.getVariable(inputNames[i]).getArr();

@@ -107,7 +107,7 @@ public class If extends DifferentialFunction implements CustomOp {
                SameDiffFunctionDefinition falseBody) {

         this.sameDiff = parent;
-        parent.putFunctionForId(getOwnName(),this);
+        parent.putOpForId(getOwnName(),this);
         this.inputVars = inputVars;
         this.predicate = predicate;

@@ -136,7 +136,7 @@ public class While extends DifferentialFunction implements CustomOp {
         this.trueBody = trueBody;
         this.blockName = blockName;
         this.dummyResult = parent.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme('f'), DataType.FLOAT, 1);
-        parent.putFunctionForId(getOwnName(),this);
+        parent.putOpForId(getOwnName(),this);

         parent.addArgsFor(inputVars,this);
         parent.addOutgoingFor(new SDVariable[]{dummyResult},this);

@@ -457,7 +457,7 @@ public class While extends DifferentialFunction implements CustomOp {

         //the output of the condition should always be a singular scalar
         //this is a safe assumption
-        val conditionVars = scopeCondition.functions();
+        val conditionVars = scopeCondition.ops();
         if(conditionVars.length < 1) {
             throw new ND4JIllegalArgumentException("No functions found!");
         }

@@ -22,7 +22,6 @@ import lombok.NoArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
 import onnx.OnnxProto3;
-import org.nd4j.autodiff.functions.DifferentialFunction;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;

@@ -40,7 +39,6 @@ import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
-import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
 import org.nd4j.linalg.util.ArrayUtil;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
@@ -71,7 +69,7 @@ public class Conv1D extends DynamicCustomOp {
         this.config = config;
         Preconditions.checkState(config.getS() >= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP());
         addArgs();
-        sameDiff.putFunctionForId(this.getOwnName(), this);
+        sameDiff.putOpForId(this.getOwnName(), this);
         sameDiff.addArgsFor(inputFunctions, this);
     }

@@ -113,12 +111,6 @@ public class Conv1D extends DynamicCustomOp {
         return config.toProperties();
     }

-    @Override
-    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
-        addArgs();
-    }
-
     @Override
     public boolean isConfigProperties() {
         return true;
@@ -129,107 +121,6 @@ public class Conv1D extends DynamicCustomOp {
         return "config";
     }

-    @Override
-    public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
-        OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
-        addArgs();
-    }
-
-
-    @Override
-    public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
-        Map<String, Map<String, AttributeAdapter>> ret = new HashMap<>();
-        Map<String, AttributeAdapter> tfMappings = new LinkedHashMap<>();
-        val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
-
-
-        tfMappings.put("kH", new ConditionalFieldValueNDArrayShapeAdapter("NHW", 2, 0, fields.get("dataFormat")));
-        tfMappings.put("kW", new ConditionalFieldValueNDArrayShapeAdapter("NHW", 3, 1, fields.get("dataFormat")));
-        tfMappings.put("sH", new ConditionalFieldValueIntIndexArrayAdapter("NHW", 2, 1, fields.get("dataFormat")));
-        tfMappings.put("sW", new ConditionalFieldValueIntIndexArrayAdapter("NHW", 3, 2, fields.get("dataFormat")));
-        tfMappings.put("isSameMode", new StringEqualsAdapter("SAME"));
-        tfMappings.put("isNHWC", new StringEqualsAdapter("NHWC"));
-
-
-        Map<String, AttributeAdapter> onnxMappings = new HashMap<>();
-        onnxMappings.put("kH", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
-        onnxMappings.put("kW", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
-        onnxMappings.put("dH", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
-        onnxMappings.put("dW", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
-        onnxMappings.put("sH", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
-        onnxMappings.put("sW", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
-        onnxMappings.put("isSameMode", new StringEqualsAdapter("SAME"));
-        onnxMappings.put("isNHWC", new StringEqualsAdapter("NHC"));
-
-        ret.put(tensorflowName(), tfMappings);
-        ret.put(onnxName(), onnxMappings);
-        return ret;
-    }
-
-    @Override
-    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
-        Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
-        Map<String, PropertyMapping> map = new HashMap<>();
-        val strideMapping = PropertyMapping.builder()
-                .tfAttrName("strides")
-                .onnxAttrName("strides")
-                .propertyNames(new String[]{"s"})
-                .build();
-
-        val kernelMapping = PropertyMapping.builder()
-                .propertyNames(new String[]{"k"})
-                .tfInputPosition(1)
-                .shapePosition(0)
-                .onnxAttrName("kernel_shape")
-                .build();
-
-        val paddingMapping = PropertyMapping.builder()
-                .onnxAttrName("padding")
-                .propertyNames(new String[]{"p"})
-                .build();
-
-        val dataFormat = PropertyMapping.builder()
-                .onnxAttrName("data_format")
-                .tfAttrName("data_format")
-                .propertyNames(new String[]{"dataFormat"})
-                .build();
-
-        val nhwc = PropertyMapping.builder()
-                .onnxAttrName("data_format")
-                .tfAttrName("data_format")
-                .propertyNames(new String[]{"isNHWC"})
-                .build();
-
-        val sameMode = PropertyMapping.builder()
-                .onnxAttrName("auto_pad")
-                .propertyNames(new String[]{"isSameMode"})
-                .tfAttrName("padding")
-                .build();
-
-        map.put("s", strideMapping);
-        map.put("k", kernelMapping);
-        map.put("p", paddingMapping);
-        map.put("isSameMode", sameMode);
-        map.put("dataFormat", dataFormat);
-        map.put("isNHWC", nhwc);
-
-        try {
-            ret.put(onnxName(), map);
-        } catch (NoOpNameFoundException e) {
-            //ignore
-        }
-
-        try {
-            ret.put(tensorflowName(), map);
-        } catch (NoOpNameFoundException e) {
-            //ignore
-        }
-
-        return ret;
-    }
-
-
     @Override
     public String opName() {
         return "conv1d";
@@ -241,16 +132,6 @@ public class Conv1D extends DynamicCustomOp {
         throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
     }

-    @Override
-    public String tensorflowName() {
-        return "Conv1D";
-    }
-
-    @Override
-    public String[] tensorflowNames() {
-        return new String[]{"Conv1D"};
-    }
-
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
         int n = args().length;
@@ -70,7 +70,7 @@ public class Conv2D extends DynamicCustomOp {
                 config.getSH(), config.getPH(), config.getDW());
         addArgs();
         if(sameDiff != null) {
-            sameDiff.putFunctionForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point
+            sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point
             sameDiff.addArgsFor(inputFunctions, this);
         }
     }
@@ -68,7 +68,7 @@ public class DeConv2D extends DynamicCustomOp {
         }

         addArgs();
-        sameDiff.putFunctionForId(this.getOwnName(), this);
+        sameDiff.putOpForId(this.getOwnName(), this);
         sameDiff.addArgsFor(inputs, this);
     }

@@ -21,7 +21,6 @@ import lombok.Getter;
 import lombok.NoArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
-import onnx.OnnxProto3;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
@@ -71,7 +70,7 @@ public class DeConv2DTF extends DynamicCustomOp {
         }

         addArgs();
-        sameDiff.putFunctionForId(this.getOwnName(), this);
+        sameDiff.putOpForId(this.getOwnName(), this);
         sameDiff.addArgsFor(inputs, this);
     }

@@ -62,7 +62,7 @@ public class DepthwiseConv2D extends DynamicCustomOp {
         this.sameDiff = sameDiff;
         this.config = config;
         addArgs();
-        sameDiff.putFunctionForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point
+        sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point
         sameDiff.addArgsFor(inputFunctions, this);
     }

@@ -76,6 +76,16 @@ public class LocalResponseNormalization extends DynamicCustomOp {
         addIArgument(config.getDepth());
     }

+    @Override
+    public boolean isConfigProperties() {
+        return true;
+    }
+
+    @Override
+    public String configFieldName(){
+        return "config";
+    }
+
     @Override
     public String opName() {
         return "lrn";
@@ -65,7 +65,7 @@ public class TensorMmul extends DynamicCustomOp {
         this.sameDiff = sameDiff;
         this.mMulTranspose = mMulTranspose;
         this.axes = dimensions;
-        if(!addedEdges && sameDiff.getOutputsForFunction(this) == null) {
+        if(!addedEdges && sameDiff.getOutputsForOp(this) == null) {
             addedEdges = true;
         }

@@ -151,7 +151,7 @@ public class Concat extends DynamicCustomOp {
             removeInputArgument(inputArgs[inputArguments().length - 1]);
         }

-        sameDiff.removeArgFromFunction(input,this);
+        sameDiff.removeArgFromOp(input,this);
     }

     @Override
@@ -72,7 +72,7 @@ public class TensorArrayConcat extends BaseTensorOp {
     public List<DataType> calculateOutputDataTypes(java.util.List<org.nd4j.linalg.api.buffer.DataType> inputDataType){
         //Same output type as the TensorArray - which is defined by input 0
         SDVariable tArr = arg(0);
-        TensorArray t3 = (TensorArray) sameDiff.getVariableOutputFunction(tArr.getVarName());
+        TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.getVarName());
         org.nd4j.linalg.api.buffer.DataType dt = t3.getTensorArrayDataType();
         return Collections.singletonList(dt);
     }
@@ -72,7 +72,7 @@ public class TensorArrayGather extends BaseTensorOp {
     public List<DataType> calculateOutputDataTypes(java.util.List<org.nd4j.linalg.api.buffer.DataType> inputDataType){
         //Same output type as the TensorArray - which is defined by input 0
         SDVariable tArr = arg(0);
-        TensorArray t3 = (TensorArray) sameDiff.getVariableOutputFunction(tArr.getVarName());
+        TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.getVarName());
         org.nd4j.linalg.api.buffer.DataType dt = t3.getTensorArrayDataType();
         return Collections.singletonList(dt);
     }
@@ -72,7 +72,7 @@ public class TensorArrayRead extends BaseTensorOp {
             dt = importDataType;
         } else {
             SDVariable tArr = arg(0);
-            DifferentialFunction op = sameDiff.getVariableOutputFunction(tArr.getVarName());
+            DifferentialFunction op = sameDiff.getVariableOutputOp(tArr.getVarName());
             TensorArray t3 = (TensorArray) op;
             dt = t3.getTensorArrayDataType();
         }
@@ -16,6 +16,7 @@

 package org.nd4j.linalg.api.ops.impl.transforms;

+import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDIndex;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;

@@ -55,6 +56,14 @@ public class Pad extends DynamicCustomOp {
         addTArgument(padValue);
     }

+    public Pad(@NonNull INDArray in, @NonNull INDArray padding, INDArray out, @NonNull Mode mode, double padValue){
+        super(null, new INDArray[]{in, padding}, out == null ? null : new INDArray[]{out});
+        Preconditions.checkState(padding.dataType().isIntType(), "Padding array must be an integer datatype, got %s", padding.dataType());
+        this.mode = mode;
+        addIArgument(mode.ordinal());
+        addTArgument(padValue);
+    }
+
     @Override
     public String opName(){
         return "pad";
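The new constructor above makes Pad directly executable on INDArrays, outside of a SameDiff graph. A hedged usage sketch (shapes and values are illustrative; Mode refers to the enum already defined on Pad, and it assumes Nd4j.exec infers the output shape when out is null):

    INDArray in = Nd4j.rand(DataType.FLOAT, 2, 2);
    // one row of padding before and after dimension 0, none on dimension 1
    INDArray padding = Nd4j.createFromArray(new int[][]{{1, 1}, {0, 0}});
    INDArray padded = Nd4j.exec(new Pad(in, padding, null, Pad.Mode.CONSTANT, 0.0))[0];   // shape [4, 2]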
@@ -1,95 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.comparison;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformBoolOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-
-import java.util.Arrays;
-import java.util.List;
-
-/**
- * Bit mask over the ndarrays as to whether
- * the components are equal or not
- *
- * @author Adam Gibson
- */
-public class OldEqualTo extends BaseTransformBoolOp {
-
-    public OldEqualTo(SameDiff sameDiff) {
-        super(sameDiff);
-    }
-
-    public OldEqualTo(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, Object[] extraArgs) {
-        super(sameDiff, i_v1, i_v2, extraArgs);
-    }
-
-    public OldEqualTo(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
-        super(sameDiff, i_v, inPlace);
-    }
-
-    public OldEqualTo(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
-        super(sameDiff, i_v1, i_v2);
-    }
-
-    public OldEqualTo(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
-        super(sameDiff, i_v1, i_v2, inPlace);
-    }
-
-    public OldEqualTo() {}
-
-    public OldEqualTo(INDArray x, INDArray y, INDArray z) {
-        super(x, y, z);
-    }
-
-    public OldEqualTo(INDArray x, INDArray y) {
-        super(x, y, null);
-    }
-
-    @Override
-    public int opNum() {
-        return 0;
-    }
-
-    @Override
-    public String opName() {
-        return "oldeq";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No Tensorflow op opName found for " + opName());
-    }
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> f1) {
-        //Equals op: 2 inputs, not continuously differentiable but 0s almost everywhere
-        return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1]));
-    }
-
-}
@@ -1,91 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.comparison;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformBoolOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-
-import java.util.Arrays;
-import java.util.List;
-
-/**
- * Bit mask over the ndarrays as to whether
- * the components are greater than or not
- *
- * @author Adam Gibson
- */
-public class OldGreaterThan extends BaseTransformBoolOp {
-    public OldGreaterThan(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
-        super(sameDiff, i_v1, i_v2);
-    }
-
-    public OldGreaterThan(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
-        super(sameDiff, i_v1, i_v2, inPlace);
-    }
-
-    public OldGreaterThan(SameDiff sameDiff) {
-        super(sameDiff);
-    }
-
-    public OldGreaterThan(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
-        super(sameDiff, i_v, inPlace);
-    }
-
-    public OldGreaterThan() {}
-
-    public OldGreaterThan(INDArray x, INDArray z) {
-        super(x, z);
-    }
-
-    public OldGreaterThan(INDArray x, INDArray y, INDArray z) {
-        super(x, y, z);
-    }
-
-    public OldGreaterThan(INDArray x) {
-        super(x);
-    }
-
-    @Override
-    public int opNum() {
-        return 1;
-    }
-
-    @Override
-    public String opName() {
-        return "oldgt";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow name found");
-    }
-
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> f1) {
-        return Arrays.asList(outputVariables()[0]);
-    }
-}
@@ -1,85 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.comparison;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformBoolOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-
-import java.util.List;
-
-/**
- * Bit mask over the ndarrays as to whether
- * the components are greater than or equal or not
- *
- * @author Adam Gibson
- */
-public class OldGreaterThanOrEqual extends BaseTransformBoolOp {
-    public OldGreaterThanOrEqual(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
-        super(sameDiff, i_v1, i_v2);
-    }
-
-    public OldGreaterThanOrEqual(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
-        super(sameDiff, i_v1, i_v2, inPlace);
-    }
-
-    public OldGreaterThanOrEqual(SameDiff sameDiff) {
-        super(sameDiff);
-    }
-
-    public OldGreaterThanOrEqual(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
-        super(sameDiff, i_v, inPlace);
-    }
-
-    public OldGreaterThanOrEqual() {}
-
-    public OldGreaterThanOrEqual(INDArray x, INDArray y, INDArray z) {
-        super(x, y, z);
-    }
-
-    public OldGreaterThanOrEqual(INDArray x) {
-        super(x);
-    }
-
-    @Override
-    public int opNum() {
-        return 4;
-    }
-
-    @Override
-    public String opName() {
-        return "oldgte";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-    }
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> f1) {
-        return null;
-    }
-}
@@ -1,91 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.comparison;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformBoolOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-
-import java.util.Arrays;
-import java.util.List;
-
-/**
- * Bit mask over the ndarrays as to whether
- * the components are less than or not
- *
- * @author Adam Gibson
- */
-public class OldLessThan extends BaseTransformBoolOp {
-    public OldLessThan(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
-        super(sameDiff, i_v1, i_v2);
-    }
-
-    public OldLessThan(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
-        super(sameDiff, i_v1, i_v2, inPlace);
-    }
-
-    public OldLessThan(SameDiff sameDiff) {
-        super(sameDiff);
-    }
-
-    public OldLessThan(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
-        super(sameDiff, i_v, inPlace);
-    }
-
-    public OldLessThan() {}
-
-    public OldLessThan(INDArray x, INDArray y, INDArray z) {
-        super(x, y, z);
-    }
-
-    public OldLessThan(INDArray x) {
-        super(x);
-    }
-
-    public OldLessThan(INDArray x, INDArray z) {
-        super(x, z);
-    }
-
-    @Override
-    public int opNum() {
-        return 2;
-    }
-
-    @Override
-    public String opName() {
-        return "oldlt";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tf opName found for " + opName());
-    }
-
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> f1) {
-        return Arrays.asList(outputVariables()[0]);
-    }
-}
@@ -1,104 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.comparison;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformBoolOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-
-import java.util.Arrays;
-import java.util.List;
-
-/**
- * Bit mask over the ndarrays as to whether
- * the components are less than or equal or not
- *
- * @author Adam Gibson
- */
-public class OldLessThanOrEqual extends BaseTransformBoolOp {
-    public OldLessThanOrEqual(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
-        super(sameDiff, i_v1, i_v2);
-    }
-
-    public OldLessThanOrEqual(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
-        super(sameDiff, i_v1, i_v2, inPlace);
-    }
-
-    public OldLessThanOrEqual(SameDiff sameDiff) {
-        super(sameDiff);
-    }
-
-    public OldLessThanOrEqual(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, Object[] extraArgs) {
-        super(sameDiff, i_v1, i_v2, extraArgs);
-    }
-
-    public OldLessThanOrEqual(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
-        super(sameDiff, i_v, inPlace);
-    }
-
-    public OldLessThanOrEqual(SameDiff sameDiff, SDVariable i_v, long[] shape, boolean inPlace, Object[] extraArgs) {
-        super(sameDiff, i_v, shape, inPlace, extraArgs);
-    }
-
-    public OldLessThanOrEqual(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs) {
-        super(sameDiff, i_v, extraArgs);
-    }
-
-    public OldLessThanOrEqual() {}
-
-    public OldLessThanOrEqual(INDArray x, INDArray z) {
-        super(x, z);
-    }
-
-    public OldLessThanOrEqual(INDArray x, INDArray y, INDArray z) {
-        super(x, y, z);
-    }
-
-    public OldLessThanOrEqual(INDArray x) {
-        super(x);
-    }
-
-    @Override
-    public int opNum() {
-        return 5;
-    }
-
-    @Override
-    public String opName() {
-        return "oldlte";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-    }
-
-
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> f1) {
-        return Arrays.asList(outputVariables()[0]);
-    }
-}
@@ -1,88 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.comparison;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-import org.nd4j.linalg.api.ops.BaseTransformSameOp;
-
-import java.util.List;
-
-/**
- * Max function
- *
- * @author Adam Gibson
- */
-public class OldMax extends BaseTransformSameOp {
-    public OldMax(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
-        super(sameDiff, i_v1, i_v2);
-    }
-
-    public OldMax(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
-        super(sameDiff, i_v1, i_v2, inPlace);
-    }
-
-    public OldMax(SameDiff sameDiff) {
-        super(sameDiff);
-    }
-
-    public OldMax(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
-        super(sameDiff, i_v, inPlace);
-    }
-
-    public OldMax() {}
-
-    public OldMax(INDArray x, INDArray y, INDArray z) {
-        super(x, y, z);
-    }
-
-    public OldMax(INDArray x) {
-        super(x);
-    }
-
-    public OldMax(INDArray ndArray, INDArray dup) {
-        super(ndArray, dup);
-    }
-
-    @Override
-    public int opNum() {
-        return 7;
-    }
-
-    @Override
-    public String opName() {
-        return "old_max_transform";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("This is not meant to be mapped, use Max instead");
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("This is not meant to be mapped, use Max instead");
-    }
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> f1) {
-        return null;
-    }
-}
@@ -1,88 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.comparison;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-import org.nd4j.linalg.api.ops.BaseTransformSameOp;
-
-import java.util.List;
-
-/**
- * Min function
- *
- * @author Adam Gibson
- */
-public class OldMin extends BaseTransformSameOp {
-    public OldMin(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
-        super(sameDiff, i_v1, i_v2);
-    }
-
-    public OldMin(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
-        super(sameDiff, i_v1, i_v2, inPlace);
-    }
-
-    public OldMin(SameDiff sameDiff) {
-        super(sameDiff);
-    }
-
-    public OldMin(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
-        super(sameDiff, i_v, inPlace);
-    }
-
-    public OldMin() {}
-
-    public OldMin(INDArray x, INDArray y, INDArray z) {
-        super(x, y, z);
-    }
-
-    public OldMin(INDArray x) {
-        super(x);
-    }
-
-    public OldMin(INDArray x, INDArray z) {
-        super(x, z);
-    }
-
-    @Override
-    public int opNum() {
-        return 8;
-    }
-
-    @Override
-    public String opName() {
-        return "old_min_transform";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("This is not meant to be mapped, use Max instead");
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("This is not meant to be mapped, use Max instead");
-    }
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> f1) {
-        return null;
-    }
-}
@@ -1,87 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.comparison;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformBoolOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-
-import java.util.Arrays;
-import java.util.List;
-
-/**
- * Not equal to function:
- * Bit mask over whether 2 elements are not equal or not
- *
- * @author Adam Gibson
- */
-public class OldNotEqualTo extends BaseTransformBoolOp {
-    public OldNotEqualTo(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
-        super(sameDiff, i_v1, i_v2);
-    }
-
-    public OldNotEqualTo(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
-        super(sameDiff, i_v1, i_v2, inPlace);
-    }
-
-    public OldNotEqualTo() {
-    }
-
-    public OldNotEqualTo(INDArray x) {
-        super(x);
-    }
-
-    public OldNotEqualTo(INDArray x, INDArray y, INDArray z) {
-        super(x, y, z);
-    }
-
-    public OldNotEqualTo(INDArray x, INDArray z) {
-        super(x, z);
-    }
-
-    @Override
-    public int opNum() {
-        return 6;
-    }
-
-    @Override
-    public String opName() {
-        return "old_neq";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No op name found");
-    }
-
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> i_v) {
-
-        return Arrays.asList(f().neg(i_v.get(0)));
-    }
-
-}
@@ -22,11 +22,13 @@ import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;

 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import org.nd4j.linalg.ops.transforms.Transforms;

 /**
  * Arc Tangent elementwise function

@@ -39,6 +41,15 @@ public class ATan2 extends BaseDynamicTransformOp {
         super(sameDiff, new SDVariable[] {y, x} ,false);
     }

+    /**
+     * Note that the order of x and y matches {@link java.lang.Math#atan2(double, double)},
+     * and is reversed compared to OldATan2.
+     * See {@link Transforms#atan2(org.nd4j.linalg.api.ndarray.INDArray, org.nd4j.linalg.api.ndarray.INDArray)}
+     */
+    public ATan2(INDArray x, INDArray y, INDArray z) {
+        super(new INDArray[]{x, y}, new INDArray[]{ z });
+    }
+
     public ATan2() {}

     @Override
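A short hedged sketch of the array-based constructor added above (values illustrative; x = y = 1 sidesteps the argument-order question, since atan2(1, 1) = pi/4 either way):

    INDArray x = Nd4j.createFromArray(1.0);
    INDArray y = Nd4j.createFromArray(1.0);
    INDArray z = Nd4j.createUninitialized(DataType.DOUBLE, 1);
    Nd4j.exec(new ATan2(x, y, z));   // z is filled with ~pi/4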
@@ -46,6 +46,9 @@ public class EqualTo extends BaseDynamicTransformOp {
         super(inputs, outputs);
     }

+    public EqualTo(INDArray x, INDArray y, INDArray z){
+        this(new INDArray[]{x, y}, new INDArray[]{z});
+    }
+
     @Override
     public String opName() {
@@ -47,7 +47,9 @@ public class GreaterThan extends BaseDynamicTransformOp {
         super(inputs, outputs);
     }

+    public GreaterThan(INDArray x, INDArray y, INDArray z){
+        this(new INDArray[]{x, y}, new INDArray[]{z});
+    }
+
     @Override
     public String opName() {
@@ -46,6 +46,10 @@ public class GreaterThanOrEqual extends BaseDynamicTransformOp {
         super(inputs, outputs);
     }

+    public GreaterThanOrEqual(INDArray x, INDArray y, INDArray z){
+        this(new INDArray[]{x, y}, new INDArray[]{z});
+    }
+
     @Override
     public int opNum() {
         return 11;
@@ -47,6 +47,10 @@ public class LessThan extends BaseDynamicTransformOp {
         super(inputs, outputs);
     }

+    public LessThan(INDArray x, INDArray y, INDArray z){
+        this(new INDArray[]{x, y}, new INDArray[]{z});
+    }
+
     @Override
     public String opName() {
         return "less";
@@ -45,6 +45,11 @@ public class LessThanOrEqual extends BaseDynamicTransformOp {
     public LessThanOrEqual( INDArray[] inputs, INDArray[] outputs) {
         super(inputs, outputs);
     }

+    public LessThanOrEqual(INDArray x, INDArray y, INDArray z){
+        this(new INDArray[]{x, y}, new INDArray[]{z});
+    }
+
     @Override
     public String opName() {
         return "less_equal";
@@ -46,6 +46,9 @@ public class NotEqualTo extends BaseDynamicTransformOp {
         super(inputs, outputs);
     }

+    public NotEqualTo(INDArray x, INDArray y, INDArray z){
+        this(new INDArray[]{x, y}, new INDArray[]{z});
+    }
+
     @Override
     public String opName() {
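All six comparison transforms gain the same three-argument convenience constructor, so ad-hoc execution no longer requires hand-building the input and output arrays. A brief hedged sketch (values illustrative):

    INDArray a = Nd4j.createFromArray(1.0, 2.0, 3.0);
    INDArray b = Nd4j.createFromArray(3.0, 2.0, 1.0);
    INDArray out = Nd4j.createUninitialized(DataType.BOOL, a.shape());
    Nd4j.exec(new LessThanOrEqual(a, b, out));   // out -> [true, true, false]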
@@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 
 import java.util.Arrays;
@@ -38,6 +39,27 @@ public class Reverse extends DynamicCustomOp {
     public Reverse() {
     }
 
+    /**
+     * Inplace reverse. See {@link #Reverse(INDArray, INDArray)}
+     */
+    public Reverse(INDArray x){
+        this(x, x);
+        this.inPlace = true;
+    }
+
+    /**
+     * Reverses whole array for compatibility with OldReverse.
+     *
+     * Note that otherwise, passing null or empty dimensions will result in a noop.
+     */
+    public Reverse(INDArray x, INDArray z){
+        super(new INDArray[]{x}, new INDArray[]{z});
+        this.dimensions = new int[x.rank()];
+        for(int i = 0 ; i < this.dimensions.length ; i++)
+            this.dimensions[i] = i;
+        addIArgument(dimensions);
+    }
+
     @Override
     public String opName() {
         return "reverse";
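A hedged sketch of the two new Reverse constructors above (not part of the diff): the one-argument form reverses in place, while the two-argument form writes to a separate output array, reversing along every dimension for parity with the removed OldReverse.

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
    import org.nd4j.linalg.factory.Nd4j;

    public class ReverseConstructorSketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.createFromArray(new double[]{1, 2, 3, 4}).reshape(2, 2);

            // Out-of-place: z receives x reversed along all dimensions
            INDArray z = Nd4j.createUninitialized(x.dataType(), x.shape());
            Nd4j.getExecutioner().exec(new Reverse(x, z));

            // In-place: x itself is reversed (the constructor sets inPlace = true)
            Nd4j.getExecutioner().exec(new Reverse(x));
        }
    }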
@@ -52,6 +52,9 @@ public class FModOp extends BaseTransformSameOp {
     public FModOp(INDArray x, INDArray z) {
         super(x, z);
     }
 
+    public FModOp(INDArray x, INDArray y, INDArray z) {
+        super(x, y, z);
+    }
+
     public FModOp(INDArray x) {
         super(x);
@@ -1,89 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-import org.nd4j.linalg.api.ops.BaseTransformSameOp;
-
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * @deprecated Use {@link AddOp}
- */
-@Deprecated
-public class OldAddOp extends BaseTransformAnyOp {
-    public OldAddOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
-        super(sameDiff, i_v1, i_v2);
-    }
-
-    public OldAddOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
-        super(sameDiff, i_v1, i_v2, inPlace);
-    }
-
-    public OldAddOp() {}
-
-    public OldAddOp(INDArray x) {
-        super(x);
-    }
-
-    public OldAddOp(INDArray x, INDArray z) {
-        super(x, z);
-    }
-
-    public OldAddOp(INDArray x, INDArray y, INDArray z) {
-        super(x, y, z);
-    }
-
-    @Override
-    public int opNum() {
-        return 0;
-    }
-
-    @Override
-    public String opName() {
-        return "old_add";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-    }
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        SDVariable gradWrtX = f().div(i_v.get(0),rarg());
-        SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(larg(),rarg()));
-
-        List<SDVariable> ret = new ArrayList<>(2);
-        ret.add(gradWrtX);
-        ret.add(gradWrtY);
-        return ret;
-    }
-
-}
@@ -1,86 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
-import org.nd4j.linalg.api.ops.BaseTransformFloatOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-import org.nd4j.linalg.api.ops.BaseTransformSameOp;
-
-import java.util.List;
-
-/**
- * atan2 operation
- *
- * @author raver119@gmail.com
- */
-public class OldAtan2Op extends BaseTransformAnyOp {
-    public OldAtan2Op(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
-        super(sameDiff, i_v1, i_v2);
-    }
-
-    public OldAtan2Op(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
-        super(sameDiff, i_v1, i_v2, inPlace);
-    }
-
-    public OldAtan2Op(SameDiff sameDiff) {
-        super(sameDiff);
-    }
-
-    public OldAtan2Op(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
-        super(sameDiff, i_v, inPlace);
-    }
-
-    public OldAtan2Op() {}
-
-    public OldAtan2Op(INDArray x, INDArray y) {
-        super(x, y, x);
-    }
-
-    public OldAtan2Op(INDArray x, INDArray y, INDArray z) {
-        super(x, y, z);
-    }
-
-    @Override
-    public int opNum() {
-        return 16;
-    }
-
-    @Override
-    public String opName() {
-        return "old_atan2";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        return "ATan2";
-    }
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> f1) {
-        return null;
-    }
-}
@@ -1,88 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-import org.nd4j.linalg.api.ops.BaseTransformSameOp;
-
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * @deprecated Use {@link DivOp}
- */
-@Deprecated
-public class OldDivOp extends BaseTransformAnyOp {
-    public OldDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
-        super(sameDiff, i_v1, i_v2);
-    }
-
-    public OldDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
-        super(sameDiff, i_v1, i_v2, inPlace);
-    }
-
-    public OldDivOp() {}
-
-    public OldDivOp(INDArray x, INDArray y, INDArray z) {
-        super(x, y, z);
-    }
-
-    public OldDivOp(INDArray x) {
-        super(x);
-    }
-
-    public OldDivOp(INDArray x, INDArray z) {
-        super(x, z);
-    }
-
-    @Override
-    public int opNum() {
-        return 2;
-    }
-
-    @Override
-    public String opName() {
-        return "olddiv";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-    }
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        SDVariable gradWrtX = f().div(i_v.get(0),rarg());
-        SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(larg(),rarg()));
-        List<SDVariable> ret = new ArrayList<>(2);
-        ret.add(gradWrtX);
-        ret.add(gradWrtY);
-        return ret;
-    }
-
-}
@@ -1,88 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-import org.nd4j.linalg.api.ops.BaseTransformSameOp;
-
-import java.util.List;
-
-/**
- * Floating point remainder
- *
- * @author raver119@gmail.com
- */
-public class OldFModOp extends BaseTransformAnyOp {
-    public OldFModOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
-        super(sameDiff, i_v1, i_v2);
-    }
-
-    public OldFModOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
-        super(sameDiff, i_v1, i_v2, inPlace);
-    }
-
-    public OldFModOp(SameDiff sameDiff) {
-        super(sameDiff);
-    }
-
-    public OldFModOp(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
-        super(sameDiff, i_v, inPlace);
-    }
-
-    public OldFModOp() {}
-
-    public OldFModOp(INDArray x) {
-        super(x);
-    }
-
-    public OldFModOp(INDArray x, INDArray z) {
-        super(x, z);
-    }
-    public OldFModOp(INDArray x, INDArray y, INDArray z) {
-        super(x, y, z);
-    }
-
-    @Override
-    public int opNum() {
-        return 15;
-    }
-
-    @Override
-    public String opName() {
-        return "oldfmod";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-    }
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> f1) {
-        return null;
-    }
-}
@@ -1,89 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-import org.nd4j.linalg.api.ops.BaseTransformSameOp;
-
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * Truncated division operation
- *
- * @author Adam Gibson
- */
-public class OldFloorDivOp extends BaseTransformAnyOp {
-    public OldFloorDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
-        super(sameDiff, i_v1, i_v2);
-    }
-
-    public OldFloorDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
-        super(sameDiff, i_v1, i_v2, inPlace);
-    }
-
-    public OldFloorDivOp() {}
-
-    public OldFloorDivOp(INDArray x) {
-        super(x);
-    }
-
-    public OldFloorDivOp(INDArray x, INDArray z) {
-        super(x, z);
-    }
-
-    public OldFloorDivOp(INDArray x, INDArray y, INDArray z) {
-        super(x, y, z);
-    }
-
-    @Override
-    public int opNum() {
-        return 18;
-    }
-
-    @Override
-    public String opName() {
-        return "oldfloordiv";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-    }
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        SDVariable gradWrtX = f().div(i_v.get(0),rarg());
-        SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(larg(),rarg()));
-        List<SDVariable> ret = new ArrayList<>(2);
-        ret.add(gradWrtX);
-        ret.add(gradWrtY);
-        return ret;
-    }
-
-}
@@ -1,91 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-import org.nd4j.linalg.api.ops.BaseTransformSameOp;
-
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * @deprecated Use {@link MulOp}
- */
-@Deprecated
-public class OldMulOp extends BaseTransformAnyOp {
-    public OldMulOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
-        super(sameDiff, i_v1, i_v2);
-    }
-
-    public OldMulOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
-        super(sameDiff, i_v1, i_v2, inPlace);
-    }
-
-    public OldMulOp() {}
-
-    public OldMulOp(INDArray x) {
-        super(x);
-    }
-
-    public OldMulOp(INDArray x, INDArray z) {
-        super(x, z);
-    }
-
-    public OldMulOp(INDArray x, INDArray y, INDArray z) {
-        super(x, y, z);
-    }
-
-    @Override
-    public int opNum() {
-        return 3;
-    }
-
-    @Override
-    public String opName() {
-        return "oldmul";
-    }
-
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-    }
-
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        SDVariable g = sameDiff.setupFunction(i_v.get(0));
-        SDVariable gradWrtX = f().mul(g,rarg());
-        SDVariable gradWrtY = f().mul(g,larg());
-        List<SDVariable> ret = new ArrayList<>(2);
-        ret.add(gradWrtX);
-        ret.add(gradWrtY);
-        return ret;
-    }
-
-}
@@ -1,87 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-import org.nd4j.linalg.api.ops.BaseTransformSameOp;
-
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * @deprecated Use {@link RDivOp}
- */
-@Deprecated
-public class OldRDivOp extends BaseTransformAnyOp {
-    public OldRDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
-        super(sameDiff, i_v1, i_v2);
-    }
-
-    public OldRDivOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
-        super(sameDiff, i_v1, i_v2, inPlace);
-    }
-
-    public OldRDivOp() {}
-
-    public OldRDivOp(INDArray x) {
-        super(x);
-    }
-
-    public OldRDivOp(INDArray x, INDArray z) {
-        super(x, z);
-    }
-
-    public OldRDivOp(INDArray x, INDArray y, INDArray z) {
-        super(x, y, z);
-    }
-
-    @Override
-    public int opNum() {
-        return 11;
-    }
-
-    @Override
-    public String opName() {
-        return "oldrdiv";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-    }
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        SDVariable gradWrtX = f().div(i_v.get(0),larg());
-        SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(rarg(),larg()));
-        List<SDVariable> ret = new ArrayList<>(2);
-        ret.add(gradWrtX);
-        ret.add(gradWrtY);
-        return ret;
-    }
-
-}
@@ -1,87 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
-
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * @deprecated Use {@link RSubOp}
- */
-@Deprecated
-public class OldRSubOp extends BaseTransformAnyOp {
-    public OldRSubOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
-        super(sameDiff, i_v1, i_v2);
-    }
-
-    public OldRSubOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
-        super(sameDiff, i_v1, i_v2, inPlace);
-    }
-
-    public OldRSubOp() {}
-
-    public OldRSubOp(INDArray x) {
-        super(x);
-    }
-
-    public OldRSubOp(INDArray x, INDArray z) {
-        super(x, z);
-    }
-
-    public OldRSubOp(INDArray x, INDArray y, INDArray z) {
-        super(x, y, z);
-    }
-
-    @Override
-    public int opNum() {
-        return 5;
-    }
-
-    @Override
-    public String opName() {
-        return "old_rsub";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-    }
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        SDVariable gradWrtX = f().div(i_v.get(0),rarg());
-        SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(larg(),rarg()));
-
-        List<SDVariable> ret = new ArrayList<>(2);
-        ret.add(gradWrtX);
-        ret.add(gradWrtY);
-        return ret;
-    }
-
-}
@@ -1,89 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-import org.nd4j.linalg.api.ops.BaseTransformSameOp;
-
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * @deprecated Use {@link SubOp}
- */
-@Deprecated
-public class OldSubOp extends BaseTransformAnyOp {
-    public OldSubOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
-        super(sameDiff, i_v1, i_v2);
-    }
-
-    public OldSubOp(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
-        super(sameDiff, i_v1, i_v2, inPlace);
-    }
-
-    public OldSubOp() {}
-
-    public OldSubOp(INDArray x) {
-        super(x);
-    }
-
-    public OldSubOp(INDArray x, INDArray z) {
-        super(x, z);
-    }
-
-    public OldSubOp(INDArray x, INDArray y, INDArray z) {
-        super(x, y, z);
-    }
-
-    @Override
-    public int opNum() {
-        return 6;
-    }
-
-    @Override
-    public String opName() {
-        return "old_sub";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-    }
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        SDVariable gradWrtX = f().div(i_v.get(0),rarg());
-        SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(larg(),rarg()));
-
-        List<SDVariable> ret = new ArrayList<>(2);
-        ret.add(gradWrtX);
-        ret.add(gradWrtY);
-        return ret;
-    }
-
-}
@@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.api.shape.Shape;
@@ -40,6 +41,10 @@ public class Identity extends BaseDynamicTransformOp {
         super(sd, new SDVariable[]{input}, false);
     }
 
+    public Identity(INDArray x, INDArray z){
+        super(new INDArray[]{x}, new INDArray[]{z});
+    }
+
     public Identity(){ }
 
     @Override
@@ -1,77 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.same;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-import org.nd4j.linalg.api.ops.BaseTransformSameOp;
-
-import java.util.Arrays;
-import java.util.List;
-import java.util.UUID;
-
-/**
- * Identity function
- *
- * @author Adam Gibson
- */
-public class OldIdentity extends BaseTransformSameOp {
-    public OldIdentity(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
-        super(sameDiff, i_v, inPlace);
-    }
-
-    public OldIdentity() {
-    }
-
-    public OldIdentity(INDArray x, INDArray z) {
-        super(x, z);
-    }
-
-    public OldIdentity(INDArray x) {
-        super(x);
-    }
-
-    @Override
-    public int opNum() {
-        return 15;
-    }
-
-    @Override
-    public String opName() {
-        return "old_identity";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("This op does not work with onnx.");
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("This op does not work with tensorflow.");
-    }
-
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        return i_v;
-    }
-
-}
@@ -1,74 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.same;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformSameOp;
-
-import java.util.Arrays;
-import java.util.List;
-
-/**
- * OldReverse op
- */
-public class OldReverse extends BaseTransformSameOp {
-    public OldReverse(SameDiff sameDiff, SDVariable i_v, int... dimensions) {
-        super(sameDiff, i_v, false);
-        this.dimensions = dimensions;
-    }
-
-    public OldReverse() {
-    }
-
-    public OldReverse(INDArray x, INDArray z) {
-        super(x, z);
-    }
-
-    public OldReverse(INDArray x) {
-        super(x);
-    }
-
-    @Override
-    public int opNum() {
-        return 20;
-    }
-
-    @Override
-    public String opName() {
-        return "old_reverse";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-    }
-
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> f1) {
-        SDVariable ret = f().reverse(f1.get(0), dimensions);
-        return Arrays.asList(ret);
-    }
-}
@@ -17,6 +17,7 @@
 package org.nd4j.linalg.convolution;
 
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.transforms.Pad.Mode;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.INDArrayIndex;
 import org.nd4j.linalg.indexing.NDArrayIndex;
@@ -129,8 +130,7 @@ public class OldConvolution {
         long w = img.size(3);
         long outHeight = outSize(h, kh, sy, ph, coverAll);
         long outWidth = outSize(w, kw, sx, pw, coverAll);
-        INDArray padded = Nd4j.pad(img, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}},
-                        Nd4j.PadMode.CONSTANT);
+        INDArray padded = Nd4j.pad(img, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}}, Mode.CONSTANT, pval);
         INDArray ret = Nd4j.create(n, c, kh, kw, outHeight, outWidth);
         for (int i = 0; i < kh; i++) {
             //offset for the row based on the stride and output height
@@ -35,7 +35,8 @@ public class OmpNumThreadsAction implements EnvironmentalAction {
         val skipper = System.getenv(ND4JEnvironmentVars.ND4J_SKIP_BLAS_THREADS);
         if (skipper == null) {
             // we infer num threads only if skipper undefined
-            Nd4j.setNumThreads(v);
+            // Nd4j.setNumThreads(v);
+            // method does not do anything anymore and was removed
         }
     }
 }
@@ -20,6 +20,14 @@ import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.broadcast.*;
 import org.nd4j.linalg.api.ops.impl.broadcast.bool.*;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Min;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*;
 import org.nd4j.linalg.api.ops.impl.transforms.comparison.*;
 import org.nd4j.linalg.api.ops.impl.transforms.same.AMax;
@@ -42,7 +50,7 @@ public class Broadcast {
     public static INDArray add(INDArray x, INDArray y, INDArray z, int... dimensions) {
         if(dimensions == null || dimensions.length == 0) {
             validateShapesNoDimCase(x,y,z);
-            return Nd4j.getExecutioner().exec(new OldAddOp(x,y,z));
+            return Nd4j.getExecutioner().exec(new AddOp(x,y,z))[0];
         }
 
         return Nd4j.getExecutioner().exec(new BroadcastAddOp(x,y,z,dimensions));
@@ -66,7 +74,7 @@ public class Broadcast {
     public static INDArray div(INDArray x, INDArray y, INDArray z, int... dimensions) {
         if(dimensions == null || dimensions.length == 0) {
             validateShapesNoDimCase(x,y,z);
-            return Nd4j.getExecutioner().exec(new OldDivOp(x,y,z));
+            return Nd4j.getExecutioner().exec(new DivOp(x,y,z))[0];
         }
 
         return Nd4j.getExecutioner().exec(new BroadcastDivOp(x,y,z,dimensions));
@@ -78,7 +86,7 @@ public class Broadcast {
     public static INDArray eq(INDArray x, INDArray y, INDArray z, int... dimensions) {
         if(dimensions == null || dimensions.length == 0) {
             validateShapesNoDimCase(x,y,z);
-            return Nd4j.getExecutioner().exec(new OldEqualTo(x,y,z));
+            return Nd4j.getExecutioner().exec(new EqualTo(x,y,z))[0];
         }
         return Nd4j.getExecutioner().exec(new BroadcastEqualTo(x,y,z,dimensions));
     }
@@ -89,7 +97,7 @@ public class Broadcast {
     public static INDArray gt(INDArray x, INDArray y, INDArray z, int... dimensions) {
         if(dimensions == null || dimensions.length == 0) {
             validateShapesNoDimCase(x,y,z);
-            return Nd4j.getExecutioner().exec(new OldGreaterThan(x,y,z));
+            return Nd4j.getExecutioner().exec(new GreaterThan(x,y,z))[0];
         }
 
         return Nd4j.getExecutioner().exec(new BroadcastGreaterThan(x,y,z,dimensions));
@@ -101,7 +109,7 @@ public class Broadcast {
     public static INDArray gte(INDArray x, INDArray y, INDArray z, int... dimensions) {
         if(dimensions == null || dimensions.length == 0) {
             validateShapesNoDimCase(x,y,z);
-            return Nd4j.getExecutioner().exec(new OldGreaterThanOrEqual(x,y,z));
+            return Nd4j.getExecutioner().exec(new GreaterThanOrEqual(x,y,z))[0];
         }
 
         return Nd4j.getExecutioner().exec(new BroadcastGreaterThanOrEqual(x,y,z,dimensions));
@@ -113,7 +121,7 @@ public class Broadcast {
     public static INDArray lt(INDArray x, INDArray y, INDArray z, int... dimensions) {
         if(dimensions == null || dimensions.length == 0) {
             validateShapesNoDimCase(x,y,z);
-            return Nd4j.getExecutioner().exec(new OldLessThan(x,y,z));
+            return Nd4j.getExecutioner().exec(new LessThan(x,y,z))[0];
         }
 
         return Nd4j.getExecutioner().exec(new BroadcastLessThan(x,y,z,dimensions));
@@ -125,7 +133,7 @@ public class Broadcast {
     public static INDArray lte(INDArray x, INDArray y, INDArray z, int... dimensions) {
         if(dimensions == null || dimensions.length == 0) {
             validateShapesNoDimCase(x,y,z);
-            return Nd4j.getExecutioner().exec(new OldLessThanOrEqual(x,y,z));
+            return Nd4j.getExecutioner().exec(new LessThanOrEqual(x,y,z))[0];
         }
 
         return Nd4j.getExecutioner().exec(new BroadcastLessThanOrEqual(x,y,z,dimensions));
@@ -137,7 +145,7 @@ public class Broadcast {
     public static INDArray mul(INDArray x, INDArray y, INDArray z, int... dimensions) {
         if(dimensions == null || dimensions.length == 0) {
             validateShapesNoDimCase(x,y,z);
-            return Nd4j.getExecutioner().exec(new OldMulOp(x,y,z));
+            return Nd4j.getExecutioner().exec(new MulOp(x,y,z))[0];
        }
 
         return Nd4j.getExecutioner().exec(new BroadcastMulOp(x,y,z,dimensions));
@@ -149,7 +157,7 @@ public class Broadcast {
     public static INDArray neq(INDArray x, INDArray y, INDArray z, int... dimensions) {
         if(dimensions == null || dimensions.length == 0) {
             validateShapesNoDimCase(x,y,z);
-            return Nd4j.getExecutioner().exec(new OldNotEqualTo(x,y,z));
+            return Nd4j.getExecutioner().exec(new NotEqualTo(x,y,z))[0];
         }
 
         return Nd4j.getExecutioner().exec(new BroadcastNotEqual(x,y,z,dimensions));
@@ -161,7 +169,7 @@ public class Broadcast {
     public static INDArray rdiv(INDArray x, INDArray y, INDArray z, int... dimensions) {
         if(dimensions == null || dimensions.length == 0) {
             validateShapesNoDimCase(x,y,z);
-            return Nd4j.getExecutioner().exec(new OldRDivOp(x,y,z));
+            return Nd4j.getExecutioner().exec(new RDivOp(x,y,z))[0];
         }
 
         return Nd4j.getExecutioner().exec(new BroadcastRDivOp(x,y,z,dimensions));
@@ -173,7 +181,7 @@ public class Broadcast {
     public static INDArray rsub(INDArray x, INDArray y, INDArray z, int... dimensions) {
         if(dimensions == null || dimensions.length == 0) {
             validateShapesNoDimCase(x,y,z);
-            return Nd4j.getExecutioner().exec(new OldSubOp(x,y,z));
+            return Nd4j.getExecutioner().exec(new SubOp(x,y,z))[0];
         }
 
         return Nd4j.getExecutioner().exec(new BroadcastRSubOp(x,y,z,dimensions));
@@ -185,7 +193,7 @@ public class Broadcast {
     public static INDArray sub(INDArray x, INDArray y, INDArray z, int... dimensions) {
         if(dimensions == null || dimensions.length == 0) {
             validateShapesNoDimCase(x,y,z);
-            return Nd4j.getExecutioner().exec(new OldSubOp(x,y,z));
+            return Nd4j.getExecutioner().exec(new SubOp(x,y,z))[0];
         }
 
         return Nd4j.getExecutioner().exec(new BroadcastSubOp(x,y,z,dimensions));
@@ -197,7 +205,7 @@ public class Broadcast {
     public static INDArray max(INDArray x, INDArray y, INDArray z, int... dimensions) {
         if(dimensions == null || dimensions.length == 0) {
             validateShapesNoDimCase(x,y,z);
-            return Nd4j.getExecutioner().exec(new OldMax(x,y,z));
+            return Nd4j.getExecutioner().exec(new Max(x,y,z))[0];
         }
 
 
@@ -210,7 +218,7 @@ public class Broadcast {
     public static INDArray min(INDArray x, INDArray y, INDArray z, int... dimensions) {
         if(dimensions == null || dimensions.length == 0) {
             validateShapesNoDimCase(x,y,z);
-            return Nd4j.getExecutioner().exec(new OldMin(x,y,z));
+            return Nd4j.getExecutioner().exec(new Min(x,y,z))[0];
         }
 
 
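The Broadcast hunks above all make the same two substitutions: the deprecated Old* transform is replaced by its current DynamicCustomOp equivalent, and a trailing [0] is added because that exec overload returns every op output as an INDArray[] rather than a single array. A hedged sketch of the pattern (not part of the diff; it assumes AddOp lives in the pairwise.arithmetic package, as MulOp does):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
    import org.nd4j.linalg.factory.Nd4j;

    public class ExecOutputSketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.createFromArray(1.0, 2.0);
            INDArray y = Nd4j.createFromArray(3.0, 4.0);
            INDArray z = Nd4j.createUninitialized(x.dataType(), x.shape());

            // exec(CustomOp) returns all outputs; these ops have exactly one,
            // so the result is taken at index 0 (the populated z array)
            INDArray[] outputs = Nd4j.getExecutioner().exec(new AddOp(x, y, z));
            INDArray result = outputs[0];
            System.out.println(result); // expected: [4.0, 6.0]
        }
    }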
@@ -57,8 +57,10 @@ import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
 import org.nd4j.linalg.api.ops.impl.shape.Diag;
 import org.nd4j.linalg.api.ops.impl.shape.DiagPart;
 import org.nd4j.linalg.api.ops.impl.shape.Stack;
+import org.nd4j.linalg.api.ops.impl.transforms.Pad;
+import org.nd4j.linalg.api.ops.impl.transforms.Pad.Mode;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
 import org.nd4j.linalg.api.ops.impl.shape.Tile;
-import org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse;
 import org.nd4j.linalg.api.ops.random.custom.RandomExponential;
 import org.nd4j.linalg.api.ops.random.impl.*;
 import org.nd4j.linalg.api.rng.DefaultRandom;
@ -182,93 +184,69 @@ public class Nd4j {
|
||||||
nd4j.initContext();
|
nd4j.initContext();
|
||||||
}
|
}
|
||||||
|
|
||||||
public enum PadMode {
|
|
||||||
CONSTANT, EDGE, LINEAR_RAMP, MAXIMUM, MEAN, MEDIAN, MINIMUM, REFLECT, SYMMETRIC, WRAP
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* See {@link #pad(INDArray, int[][], List, PadMode)} with zero padding. (zeros for constantValues).
|
* See {@link #pad(INDArray, INDArray)}. Uses 0 padding.
|
||||||
*/
|
*/
|
||||||
public static INDArray pad(INDArray toPad, int[][] padWidth, PadMode padMode) {
|
public static INDArray pad(@NonNull INDArray toPad, @NonNull int[][] padWidth){
|
||||||
return pad(toPad, padWidth, ArrayUtil.zerosMatrix(toPad.shape()), padMode);
|
return pad(toPad, Nd4j.createFromArray(padWidth));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Pad the given ndarray to the size along each dimension
|
* See {@link #pad(INDArray, INDArray)}. Uses 0 padding, and uses padWidth for all dimensions.
|
||||||
|
*/
|
||||||
|
public static INDArray pad(@NonNull INDArray toPad, @NonNull int... padWidth){
|
||||||
|
return pad(toPad, padWidth, Mode.CONSTANT, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #pad(INDArray, INDArray, Pad.Mode, double)} with zero padding (zeros for padValue).
|
||||||
|
*/
|
||||||
|
public static INDArray pad(INDArray toPad, INDArray padding) {
|
||||||
|
return pad(toPad, padding, Mode.CONSTANT, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #pad(INDArray, INDArray, Mode, double)}.
|
||||||
|
*/
|
||||||
|
public static INDArray pad(@NonNull INDArray toPad, @NonNull int[][] padWidth, @NonNull Pad.Mode padMode, double padValue){
|
||||||
|
return pad(toPad, Nd4j.createFromArray(padWidth), padMode, padValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #pad(INDArray, INDArray, Mode, double)}, uses padWidth for all dimensions.
|
||||||
|
*/
|
||||||
|
public static INDArray pad(@NonNull INDArray toPad, @NonNull int[] padWidth, @NonNull Pad.Mode padMode, double padValue){
|
||||||
|
int[][] pads = new int[toPad.rank()][padWidth.length];
|
||||||
|
for(int i = 0 ; i < pads.length ; i++){
|
||||||
|
pads[i] = padWidth;
|
||||||
|
}
|
||||||
|
return pad(toPad, pads, padMode, padValue);
|
||||||
|
}
|
||||||
|
|
||||||
     /**
      * Pad the given ndarray to the size along each dimension.
      *
      * @param toPad the ndarray to pad
      * @param padWidth the width to pad along each dimension
-     * @param constantValues the values to append for each dimension
      * @param padMode the mode to pad in
+     * @param padValue the value used during padding. Only used when padMode is {@link Pad.Mode#CONSTANT}.
      * @return the padded ndarray
      *         based on the specified mode
      */
-    public static INDArray pad(INDArray toPad, int[][] padWidth, List<double[]> constantValues, PadMode padMode) {
-        if (padMode == PadMode.CONSTANT) {
-            if (padWidth.length < toPad.rank())
-                throw new IllegalArgumentException("Please specify a pad width for each dimension");
-
-            List<int[]> sizes = new ArrayList<>();
-            for (int i = 0; i < toPad.rank(); i++) {
-                sizes.add(padWidth[i]);
-            }
-            return padImpl(toPad, sizes, constantValues);
-        }
-        throw new UnsupportedOperationException();
-    }
+    public static INDArray pad(@NonNull INDArray toPad, @NonNull INDArray padWidth, @NonNull Pad.Mode padMode, double padValue) {
+        Preconditions.checkArgument(toPad.rank() == padWidth.size(0),
+                "Must provide padding values for each dimension. Expected %s pairs for a rank %s array, got %s",
+                toPad.rank(), toPad.rank(), padWidth.size(0));
+
+        long[] newShape = new long[toPad.rank()];
+        for (int i = 0; i < newShape.length; i++) {
+            newShape[i] = toPad.size(i) + padWidth.getRow(i).sumNumber().intValue();
+        }
+
+        INDArray out = Nd4j.createUninitialized(toPad.dataType(), newShape);
+        Pad op = new Pad(toPad, padWidth, out, padMode, padValue);
+        return Nd4j.getExecutioner().exec(op)[0];
+    }

-    /**
-     * See {@link #pad(INDArray, int[][], List, PadMode)} with a 1D int[] for padWidth.
-     */
-    public static INDArray pad(INDArray toPad, int[] padWidth, List<double[]> constantValues, PadMode padMode) {
-        if (padMode == PadMode.CONSTANT) {
-            if (padWidth.length < toPad.rank())
-                throw new IllegalArgumentException("Please specify a pad width for each dimension");
-
-            toPad = Nd4j.stripOnes(toPad);
-
-            List<int[]> sizes = new ArrayList<>();
-            for (int i = 0; i < toPad.rank(); i++) {
-                sizes.add(padWidth);
-            }
-            return padImpl(toPad, sizes, constantValues);
-        }
-        throw new UnsupportedOperationException();
-    }
-
-    // common code for pad(INDArray, int[], List<double[]>, PadMode) and
-    // pad(INDArray, int[][], List<double[]>, PadMode)
-    private static INDArray padImpl(INDArray toPad, List<int[]> sizes, List<double[]> constantValues){
-        INDArray ret = toPad;
-        for (int i = 0; i < toPad.rank(); i++) {
-            int[] pad = sizes.get(i);
-            double[] constant = constantValues.get(i);
-            int padBefore = pad[0];
-            int padAfter = pad[1];
-            if (constant.length < 2) {
-                double val = constant[0];
-                constant = new double[2];
-                constant[0] = val;
-                constant[1] = val;
-            }
-
-            double beforeVal = constant[0];
-            double afterVal = constant[1];
-            ret = Nd4j.prepend(ret, padBefore, beforeVal, i);
-            ret = Nd4j.append(ret, padAfter, afterVal, i);
-        }
-        return ret;
-    }
-
-    /**
-     * See {@link #pad(INDArray, int[][], List, PadMode)} with a 1D int[] for padWidth and zero padding.
-     */
-    public static INDArray pad(INDArray toPad, int[] padWidth, PadMode padMode) {
-        return pad(toPad, padWidth, ArrayUtil.zerosMatrix(padWidth), padMode);
-    }
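Note: a quick sketch of how the reworked pad API reads at call sites. The `int[][]` form is taken from the signature above; the varargs call mirrors the `pad(toPad, padding, Mode.CONSTANT, 0)` delegation and the `Nd4j.pad(start, 5, 5)` usage in the updated PaddingTests further down, so the exact varargs signature is an assumption. Variable names are illustrative:

```java
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.Pad;
import org.nd4j.linalg.factory.Nd4j;

public class PadUsageSketch {
    public static void main(String[] args) {
        INDArray toPad = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);

        // Varargs form (assumed): pad 5 before and 5 after along every dimension, zero-filled.
        INDArray a = Nd4j.pad(toPad, 5, 5);

        // Explicit [before, after] pair per dimension, constant mode, pad value 0.
        INDArray b = Nd4j.pad(toPad, new int[][]{{1, 1}, {2, 2}}, Pad.Mode.CONSTANT, 0);

        System.out.println(a.shapeInfoToString());
        System.out.println(b.shapeInfoToString());
    }
}
```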
@@ -2639,7 +2617,7 @@ public class Nd4j {
      * @return the reversed matrix
      */
     public static INDArray reverse(INDArray reverse) {
-        return Nd4j.getExecutioner().exec(new OldReverse(reverse));
+        return Nd4j.getExecutioner().exec(new Reverse(reverse))[0];
     }
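Note: this is the recurring mechanical change in the PR — the legacy `Old*` ops executed via `exec(Op)` returned a single `INDArray`, while the replacement custom ops return an `INDArray[]` of outputs, hence the trailing `[0]` everywhere below. A minimal before/after sketch (array contents illustrative):

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
import org.nd4j.linalg.factory.Nd4j;

public class ReverseMigrationSketch {
    public static void main(String[] args) {
        INDArray in = Nd4j.create(new double[]{3, 2, 1});

        // Before: INDArray out = Nd4j.getExecutioner().exec(new OldReverse(in));
        // After: custom ops return all their outputs, so index into the result array.
        INDArray out = Nd4j.getExecutioner().exec(new Reverse(in))[0];

        System.out.println(out); // expected [1.0, 2.0, 3.0]
    }
}
```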
@@ -5961,28 +5939,7 @@ public class Nd4j {
         }
     }

-    /**
-     * This method returns maximal allowed number of threads for Nd4j.
-     * If value wasn't set in advance, max(1, availableProcessor) will be returned
-     * @return maximal allowed number of threads
-     */
-    public static int numThreads() {
-        val v = numThreads.get();
-        if (v <= 0)
-            return Math.max(1, Runtime.getRuntime().availableProcessors() / 2);
-        else
-            return v;
-    }
-
-    /**
-     * This method sets maximal allowed number of threads for Nd4j
-     * @param numthreads maximal allowed number of threads
-     */
-    public static void setNumThreads(int numthreads) {
-        numThreads.set(numthreads);
-    }
-
     public static DataType defaultFloatingPointType() {
         return defaultFloatingPointDataType.get();
     }
@@ -20,7 +20,7 @@ import lombok.Data;
 import lombok.NonNull;
 import org.apache.commons.math3.util.FastMath;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
 import org.nd4j.linalg.api.shape.Shape;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.NDArrayIndex;

@@ -104,7 +104,7 @@ public class AdaMaxUpdater implements GradientUpdater<AdaMax> {
         //u = max(B_2 * u, |grad|)
         u.muli(config.getBeta2());
         Transforms.abs(gradient, false); //In-place should be OK here, original gradient values aren't used again later
-        Nd4j.getExecutioner().exec(new OldMax(u, gradient, u));
+        Nd4j.getExecutioner().exec(new Max(u, gradient, u));

         double beta1t = FastMath.pow(config.getBeta1(), iteration + 1);
@@ -19,7 +19,7 @@ package org.nd4j.linalg.learning;
 import lombok.Data;
 import lombok.NonNull;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAddOp;
+import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
 import org.nd4j.linalg.api.shape.Shape;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.learning.config.Nesterovs;

@@ -105,6 +105,6 @@ public class NesterovsUpdater implements GradientUpdater<Nesterovs> {
         INDArray ret = vPrev.muli(momentum).addi(v.mul(-momentum - 1));
         gradient.assign(ret);
         */
-        Nd4j.getExecutioner().exec(new OldAddOp(vPrev.muli(momentum), v.mul(-momentum - 1), gradient));
+        Nd4j.getExecutioner().exec(new AddOp(vPrev.muli(momentum), v.mul(-momentum - 1), gradient));
     }
 }
@@ -30,6 +30,10 @@ import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarNot;
 import org.nd4j.linalg.api.ops.impl.shape.Cross;
 import org.nd4j.linalg.api.ops.impl.transforms.bool.BooleanNot;
 import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
 import org.nd4j.linalg.api.ops.impl.transforms.floating.*;
 import org.nd4j.linalg.api.ops.impl.transforms.comparison.*;

@@ -37,7 +41,6 @@ import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
-import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAtan2Op;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.PowPairwise;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
@@ -104,7 +107,7 @@ public class Transforms {

     public static INDArray reverse(INDArray x, boolean dup) {
-        return Nd4j.getExecutioner().exec(new OldReverse(x, dup ? x.ulike() : x));
+        return Nd4j.getExecutioner().exec(new Reverse(x, dup ? x.ulike() : x))[0];
     }
@@ -140,14 +143,15 @@ public class Transforms {

     /**
      * Atan2 operation, new INDArray instance will be returned
-     * Note the order of x and y parameters is opposite to that of java.lang.Math.atan2
+     * Note the order of x and y parameters is opposite to that of {@link java.lang.Math#atan2(double, double)}
      *
      * @param x the abscissa coordinate
      * @param y the ordinate coordinate
      * @return the theta from point (r, theta) when converting (x,y) from cartesian to polar coordinates
      */
     public static INDArray atan2(@NonNull INDArray x, @NonNull INDArray y) {
-        return Nd4j.getExecutioner().exec(new OldAtan2Op(x, y, x.ulike()));
+        // Switched on purpose, to match OldATan2 (which the javadoc was written for)
+        return Nd4j.getExecutioner().exec(new ATan2(y, x, x.ulike()))[0];
     }
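Note: easy to trip over — per the javadoc above, `Transforms.atan2(x, y)` keeps the old `OldATan2Op` argument order, which is flipped relative to `java.lang.Math.atan2(y, x)`; the call into the new `ATan2` custom op swaps the operands to preserve that contract. A small sketch of the expected equivalence (values illustrative):

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

public class Atan2OrderSketch {
    public static void main(String[] args) {
        INDArray x = Nd4j.create(new double[]{1.0});
        INDArray y = Nd4j.create(new double[]{2.0});

        // Transforms.atan2 takes (x, y); java.lang.Math.atan2 takes (y, x).
        double fromNd4j = Transforms.atan2(x, y).getDouble(0);
        double fromMath = Math.atan2(y.getDouble(0), x.getDouble(0));

        System.out.println(fromNd4j + " should equal " + fromMath);
    }
}
```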
@@ -789,7 +793,7 @@ public class Transforms {
      * @return
      */
     public static INDArray lessThanOrEqual(INDArray first, INDArray ndArray, boolean dup) {
-        return exec(new OldLessThanOrEqual(first, ndArray, Nd4j.createUninitialized(DataType.BOOL, first.shape(), first.ordering())));
+        return Nd4j.getExecutioner().exec(new LessThanOrEqual(first, ndArray, Nd4j.createUninitialized(DataType.BOOL, first.shape(), first.ordering())))[0];
     }
@@ -801,7 +805,7 @@ public class Transforms {
      * @return
      */
     public static INDArray greaterThanOrEqual(INDArray first, INDArray ndArray, boolean dup) {
-        return exec(new OldGreaterThanOrEqual(first, ndArray, Nd4j.createUninitialized(DataType.BOOL, first.shape(), first.ordering())));
+        return Nd4j.getExecutioner().exec(new GreaterThanOrEqual(first, ndArray, Nd4j.createUninitialized(DataType.BOOL, first.shape(), first.ordering())))[0];
     }
@@ -986,7 +990,7 @@ public class Transforms {
      * @return
      */
     public static INDArray identity(INDArray ndArray, boolean dup) {
-        return exec(dup ? new OldIdentity(ndArray, ndArray.ulike()) : new OldIdentity(ndArray));
+        return Nd4j.getExecutioner().exec(dup ? new Identity(ndArray, ndArray.ulike()) : new Identity(ndArray, ndArray))[0];
     }

     public static INDArray isMax(INDArray input, DataType dataType) {
@@ -16,6 +16,13 @@

 package org.nd4j.autodiff.opvalidation;

+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
 import org.junit.Ignore;
@@ -32,21 +39,20 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative;
-import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
-import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.factory.Nd4jBackend;
 import org.nd4j.linalg.ops.transforms.Transforms;

-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNull;
-
 @Slf4j
 public class LayerOpValidation extends BaseOpValidation {
     public LayerOpValidation(Nd4jBackend backend) {
@@ -311,7 +317,7 @@ public class LayerOpValidation extends BaseOpValidation {
                     SDVariable loss = sd.mean("loss", out);

                     log.info("Starting test: " + msg);
-                    TestCase tc = new TestCase(sd);
+                    TestCase tc = new TestCase(sd).gradientCheck(true);
                     String error = OpValidation.validate(tc);
                     if (error != null) {
                         failed.add(msg);
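Note: the test edits below all follow one pattern — gradient checking is now opted into explicitly per `TestCase` rather than implied (an assumption inferred from the consistent `.gradientCheck(true)` additions in this diff). A minimal sketch of the validation flow, using a hypothetical tanh/mean graph:

```java
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.validation.OpValidation;
import org.nd4j.autodiff.validation.TestCase;
import org.nd4j.linalg.factory.Nd4j;

public class GradCheckSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("in", Nd4j.rand(3, 4));
        SDVariable loss = sd.mean("loss", sd.nn().tanh("act", in));

        // Opt in to the gradient check; validate(...) returns null on success.
        TestCase tc = new TestCase(sd).gradientCheck(true).testName("tanh-mean");
        String error = OpValidation.validate(tc);
        assert error == null : error;
    }
}
```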
@@ -344,7 +350,7 @@ public class LayerOpValidation extends BaseOpValidation {

             String msg = Arrays.toString(inSizeNCHW);

-            TestCase tc = new TestCase(sd).testName(msg);
+            TestCase tc = new TestCase(sd).gradientCheck(true).testName(msg);
             String error = OpValidation.validate(tc);
             if (error != null) {
                 failed.add(msg);
@@ -552,7 +558,7 @@ public class LayerOpValidation extends BaseOpValidation {
                     SDVariable loss = sd.standardDeviation("loss", out, true);

                     log.info("Starting test: " + msg);
-                    TestCase tc = new TestCase(sd);
+                    TestCase tc = new TestCase(sd).gradientCheck(true);
                     tc.testName(msg);
                     String error = OpValidation.validate(tc);
                     if (error != null) {
@@ -660,7 +666,7 @@ public class LayerOpValidation extends BaseOpValidation {
//        System.out.println(sd.getFunction("grad").summary());

         //Gradient check:
-        TestCase tc = new TestCase(sd);
+        TestCase tc = new TestCase(sd).gradientCheck(true);
         String err = OpValidation.validate(tc);
         assertNull(err);
     }
@@ -705,7 +711,7 @@ public class LayerOpValidation extends BaseOpValidation {

         SDVariable loss = out.std(true);
         //Gradient check:
-        TestCase tc = new TestCase(sd);
+        TestCase tc = new TestCase(sd).gradientCheck(true);
         String err = OpValidation.validate(tc);
         assertNull(err);
     }
@@ -798,7 +804,7 @@ public class LayerOpValidation extends BaseOpValidation {
             exp.putScalar(next, max);
         }

-        assertNull(OpValidation.validate(new TestCase(sd)
+        assertNull(OpValidation.validate(new TestCase(sd).gradientCheck(true)
                 .expected(outPool, exp)));
     }
@@ -856,7 +862,7 @@ public class LayerOpValidation extends BaseOpValidation {
         }

         assertNull(OpValidation.validate(new TestCase(sd)
-                .expected(outPool, exp)));
+                .expected(outPool, exp).gradientCheck(true)));

     }
@@ -887,16 +893,12 @@ public class LayerOpValidation extends BaseOpValidation {
                 .build();

         SDVariable out = sd.cnn().avgPooling3d(in, pooling3DConfig);
-        out = sd.nn().tanh("out", out);
+        out = sd.nn().tanh("loss", out).shape().rename("out");

-        INDArray outArr = sd.execAndEndResult();
-        val outShape = outArr.shape();
         // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1;
-        assertArrayEquals(new long[]{mb, nIn, 4, 4, 4}, outShape);
+        INDArray outArr = Nd4j.createFromArray(mb, nIn, 4, 4, 4L);

-        SDVariable loss = out.std(true);
-        //Gradient check:
-        TestCase tc = new TestCase(sd);
+        TestCase tc = new TestCase(sd).expectedOutput("out", outArr).gradientCheck(true);
         String err = OpValidation.validate(tc);
         assertNull(err);
     }
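Note: the rewritten pooling/conv tests use a small trick instead of `execAndEndResult()`: the activation stays in the graph as the named loss, and a `shape()` op renamed to "out" exposes the output shape as a variable, so the expected shape can be asserted through `TestCase.expectedOutput` while the gradient check still runs through "loss". A condensed sketch of the same pattern on a toy graph (shapes illustrative):

```java
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.validation.OpValidation;
import org.nd4j.autodiff.validation.TestCase;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ShapeAsOutputSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("in", Nd4j.rand(new long[]{2, 3, 4}));

        // Keep the activation as the loss; expose its shape under the name "out".
        SDVariable out = sd.nn().tanh("loss", in).shape().rename("out");
        sd.setLossVariables("loss");

        INDArray expShape = Nd4j.createFromArray(2, 3, 4L);
        TestCase tc = new TestCase(sd).expectedOutput("out", expShape).gradientCheck(true);
        String err = OpValidation.validate(tc);
        assert err == null : err;
    }
}
```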
@@ -927,12 +929,16 @@ public class LayerOpValidation extends BaseOpValidation {
                 .build();

         SDVariable out = sd.cnn().maxPooling3d(in, pooling3DConfig);
-        out = sd.nn().tanh("out", out);
+        out = sd.nn().tanh("loss", out).shape().rename("out");
+
+        sd.setLossVariables("loss");

-        INDArray outArr = sd.execAndEndResult();
-        val outShape = outArr.shape();
         // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1;
-        assertArrayEquals(new long[]{mb, nIn, 27, 27, 27}, outShape);
+        INDArray outArr = Nd4j.createFromArray(mb, nIn, 27, 27, 27L);
+
+        TestCase tc = new TestCase(sd).expectedOutput("out", outArr).gradientCheck(true);
+        String err = OpValidation.validate(tc);
+        assertNull(err);
     }

     @Test
@@ -958,13 +964,58 @@
                 .build();

         SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig);
-        out = sd.nn().tanh("out", out);
+        out = sd.nn().tanh("loss", out).shape().rename("out");
+
+        sd.setLossVariables("loss");

-        INDArray outArr = sd.execAndEndResult();
-        INDArray iOut = out.getArr();
         //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27
-        val outShape = outArr.shape();
-        assertArrayEquals(new long[]{mb, nOut, 27}, outShape);
+        INDArray outArr = Nd4j.createFromArray(mb, nOut, 27L);
+        TestCase tc = new TestCase(sd).expectedOutput("out", outArr).gradientCheck(false);
+        String err = OpValidation
+                .validate(tc);
+        assertNull(err);
+    }
+
+    @Test
+    public void testConv1dForward(){
+        int nIn = 2;
+        int nOut = 1;
+        int kernel = 3;
+        int batchSize = 10;
+        int sequenceSize = 5;
+
+        SameDiff sd = SameDiff.create();
+
+        INDArray inArr = Nd4j.linspace(0, nIn * batchSize * sequenceSize, nIn * batchSize * sequenceSize)
+                .reshape(batchSize, nIn, sequenceSize);
+
+        INDArray wArr = Nd4j.linspace(0, kernel * nIn * nOut, kernel * nIn * nOut)
+                .reshape(kernel, nIn, nOut);
+
+        SDVariable in = sd.var("in", inArr);
+        SDVariable w = sd.var("w", wArr);
+
+        SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).build());
+
+        INDArray expected = Nd4j.createFromArray(
+                new double[][][]{
+                        {{82.42424f, 100.60606f, 118.78788f}},
+                        {{264.2424f, 282.4242f, 300.6061f}},
+                        {{446.0606f, 464.2424f, 482.424f}},
+                        {{627.8788f, 646.0606f, 664.2424f}},
+                        {{809.6970f, 827.8788f, 846.0606f}},
+                        {{991.5152f, 1009.69696f, 1027.8788f}},
+                        {{1173.3333f, 1191.5152f, 1209.6970f}},
+                        {{1355.1515f, 1373.3333f, 1391.5153f}},
+                        {{1536.9697f, 1555.1515f, 1573.3333f}},
+                        {{1718.7878f, 1736.9697f, 1755.1515f}}
+                }
+        );
+
+        TestCase tc = new TestCase(sd).gradientCheck(false).expectedOutput(res.getVarName(), expected);
+        String err = OpValidation.validate(tc);
+
+        assertNull(err);
     }
@@ -1000,17 +1051,61 @@
                 .build();

         SDVariable out = sd.cnn().conv3d(in, w, b, conv3DConfig);
-        out = sd.nn().tanh("out", out);
+        out = sd.nn().tanh("loss", out).shape().rename("out");
+
+        sd.setLossVariables("loss");

-        INDArray outArr = sd.execAndEndResult();
         //Expected output size, NOT same mode: out = (in - k)/d + 1 = (28-2+0)/1+1 = 27
         //Expected output size, WITH same mode: out = in/stride
-        val outShape = outArr.shape();
-        assertArrayEquals(new long[]{mb, nOut, 5, 5, 5}, outShape);
+        INDArray outArr = Nd4j.createFromArray(mb, nOut, 5, 5, 5L);

-        SDVariable loss = out.std(true);
-        //Gradient check:
-        TestCase tc = new TestCase(sd);
+        TestCase tc = new TestCase(sd).expectedOutput("out", outArr).gradientCheck(true);
+        String err = OpValidation
+                .validate(tc);
+        assertNull(err);
+    }
+
+    @Test
+    public void testDeConv3dBasic() {
+        int nIn = 4;
+        int nOut = 3;
+        int kH = 2;
+        int kW = 2;
+        int kD = 2;
+
+        int mb = 3;
+        int imgH = 5;
+        int imgW = 5;
+        int imgT = 5;
+
+        SameDiff sd = SameDiff.create();
+        INDArray inArr = Nd4j.rand(new long[]{mb, nIn, 5, 5, 5});
+        INDArray wArr = Nd4j.rand(kD, kH, kW, nOut, nIn);
+
+        SDVariable in = sd.var("in", inArr);
+        SDVariable w = sd.var("W", wArr);
+
+        DeConv3DConfig conv3DConfig = DeConv3DConfig.builder()
+                .kH(kH).kW(kW).kD(kD)
+                .sD(1).sH(1).sW(1)
+                .dH(1).dW(1).dD(1)
+                .isSameMode(true)
+                .dataFormat(DeConv3DConfig.NCDHW)
+                .build();
+
+        SDVariable out = sd.cnn().deconv3d(in, w, conv3DConfig);
+        out = sd.nn().tanh("loss", out).shape().rename("out");
+
+        sd.setLossVariables("loss");
+
+        //Expected conv3d size, NOT same mode: out = (in - k)/d + 1 = (28-2+0)/1+1 = 27
+        //Expected conv3d size, WITH same mode: out = in/stride
+        // reversed this for deconv3d
+        INDArray outArr = Nd4j.createFromArray(new long[]{mb, nOut, imgT, imgH, imgW});
+
+        TestCase tc = new TestCase(sd)
+                .expectedOutput("out", outArr)
+                .gradientCheck(true);
         String err = OpValidation.validate(tc);
         assertNull(err);
     }
@@ -1181,23 +1276,23 @@
         List<String> failed = new ArrayList<>();

         for (boolean ncdhw : new boolean[]{true, false}) {
             int nIn = inSizeNCDHW[1];
             int[] shape = (ncdhw ? inSizeNCDHW : ncdhwToNdhwc(inSizeNCDHW));

             SameDiff sd = SameDiff.create();
             SDVariable in = sd.var("in", shape);

             SDVariable out;
             String msg = "0 - conv3d+bias+same, ncdhw=" + ncdhw + " - input " + Arrays.toString(shape);

             SDVariable w0 = sd.var("w0", Nd4j.rand(new int[]{2, 2, 2, nIn, 3}).muli(10)); //[kD, kH, kW, iC, oC]
             SDVariable b0 = sd.var("b0", Nd4j.rand(new long[]{3}).muli(10));
             out = sd.cnn().conv3d(in, w0, b0, Conv3DConfig.builder()
                     .dataFormat(ncdhw ? Conv3DConfig.NCDHW : Conv3DConfig.NDHWC)
                     .isSameMode(true)
                     .kH(2).kW(2).kD(2)
                     .sD(1).sH(1).sW(-1).dW(-1)
                     .build());
         }
     }
@@ -38,10 +38,10 @@ import org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod;
 import org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication;
 import org.nd4j.linalg.api.ops.impl.shape.Cross;
 import org.nd4j.linalg.api.ops.impl.transforms.Pad;
-import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax;
-import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Min;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
 import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
@@ -1008,7 +1008,7 @@ public class TransformOpValidation extends BaseOpValidation {
             }

-            DifferentialFunction[] funcs = sd.functions();
+            DifferentialFunction[] funcs = sd.ops();
             String name = opName == null ? funcs[0].opName() : opName;
@@ -1141,11 +1141,11 @@ public class TransformOpValidation extends BaseOpValidation {
                     break;
                 case 14:
                     t = sd.max(in1, in2);
-                    tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new OldMax(ia, ib, ia.dup())));
+                    tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Max(ia, ib, ia.dup()))[0]);
                     break;
                 case 15:
                     t = sd.min(in1, in2);
-                    tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new OldMin(ia, ib, ia.dup())));
+                    tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Min(ia, ib, ia.dup()))[0]);
                     break;
                 case 16:
                     ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));
@@ -1199,7 +1199,7 @@ public class TransformOpValidation extends BaseOpValidation {
             }

-            DifferentialFunction[] funcs = sd.functions();
+            DifferentialFunction[] funcs = sd.ops();
             String name = (opName == null ? funcs[0].opName() : opName);

             String msg = "test: " + i + " - " + name;
@@ -188,11 +188,11 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
                 assertEquals(varsOrig.get(j).getVarName(), varsRestored.get(j).getVarName());
             }

-            DifferentialFunction[] fOrig = sd.functions();
-            DifferentialFunction[] fRestored = restored.functions();
+            DifferentialFunction[] fOrig = sd.ops();
+            DifferentialFunction[] fRestored = restored.ops();
             assertEquals(fOrig.length, fRestored.length);

-            for (int j = 0; j < sd.functions().length; j++) {
+            for (int j = 0; j < sd.ops().length; j++) {
                 assertEquals(fOrig[j].getClass(), fRestored[j].getClass());
             }
@@ -224,7 +224,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
             sd.save(f2, withUpdaterState);
             SameDiff r2 = SameDiff.load(f2, withUpdaterState);
             assertEquals(varsOrig.size(), r2.variables().size());
-            assertEquals(fOrig.length, r2.functions().length);
+            assertEquals(fOrig.length, r2.ops().length);
             assertEquals(sd.getLossVariables(), r2.getLossVariables());

             //Save via stream:
@@ -237,7 +237,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
             try(InputStream is = new BufferedInputStream(new FileInputStream(f3))) {
                 SameDiff r3 = SameDiff.load(is, withUpdaterState);
                 assertEquals(varsOrig.size(), r3.variables().size());
-                assertEquals(fOrig.length, r3.functions().length);
+                assertEquals(fOrig.length, r3.ops().length);
                 assertEquals(sd.getLossVariables(), r3.getLossVariables());
             }
         }
@@ -19,7 +19,6 @@ package org.nd4j.autodiff.samediff;
 import lombok.extern.slf4j.Slf4j;
 import org.junit.Test;
 import org.nd4j.autodiff.samediff.transform.*;
-import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.BaseNd4jTest;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -58,17 +57,17 @@ public class GraphTransformUtilTests extends BaseNd4jTest {

         SDVariable sub = add.sub(add2);

-        assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputFunction(add.getVarName())));
-        assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputFunction(add2.getVarName())));
-        assertFalse(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputFunction(sub.getVarName())));
+        assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add.getVarName())));
+        assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add2.getVarName())));
+        assertFalse(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(sub.getVarName())));

-        assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputFunction(add.getVarName())));
-        assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputFunction(add2.getVarName())));
-        assertFalse(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputFunction(sub.getVarName())));
+        assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add.getVarName())));
+        assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add2.getVarName())));
+        assertFalse(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(sub.getVarName())));

-        assertTrue(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputFunction(add.getVarName())));
-        assertTrue(OpPredicate.opNameMatches("ad.*").matches(sd, sd.getVariableOutputFunction(add2.getVarName())));
-        assertFalse(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputFunction(sub.getVarName())));
+        assertTrue(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(add.getVarName())));
+        assertTrue(OpPredicate.opNameMatches("ad.*").matches(sd, sd.getVariableOutputOp(add2.getVarName())));
+        assertFalse(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(sub.getVarName())));

         SubGraphPredicate p = SubGraphPredicate.withRoot(OpPredicate.classEquals(AddOp.class));
@@ -77,11 +76,11 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
         assertEquals(2, l.size());

         SubGraph sg1 = l.get(0);
-        assertTrue(sg1.getRootNode() == sd.getVariableOutputFunction(add.getVarName()));
+        assertTrue(sg1.getRootNode() == sd.getVariableOutputOp(add.getVarName()));
         assertEquals(0, sg1.getChildNodes().size());

         SubGraph sg2 = l.get(1);
-        assertTrue(sg2.getRootNode() == sd.getVariableOutputFunction(add2.getVarName()));
+        assertTrue(sg2.getRootNode() == sd.getVariableOutputOp(add2.getVarName()));
         assertEquals(0, sg2.getChildNodes().size());
     }
@@ -59,13 +59,13 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig;
 import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
 import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
 import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
-import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax;
-import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Min;
 import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
@@ -1759,11 +1759,11 @@ public class SameDiffTests extends BaseNd4jTest {
                     break;
                 case 7:
                     t = sd.max(in1, in2);
-                    expOut = Nd4j.getExecutioner().exec(new OldMax(ia, ib, ia.dup()));
+                    expOut = Nd4j.getExecutioner().exec(new Max(ia, ib, ia.dup()))[0];
                     break;
                 case 8:
                     t = sd.min(in1, in2);
-                    expOut = Nd4j.getExecutioner().exec(new OldMin(ia, ib, ia.dup()));
+                    expOut = Nd4j.getExecutioner().exec(new Min(ia, ib, ia.dup()))[0];
                     break;
                 case 9:
                     ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));
@@ -16,7 +16,6 @@

 package org.nd4j.imports.TFGraphs;

-import com.google.common.primitives.Doubles;
 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
 import org.apache.commons.io.FilenameUtils;
@@ -38,7 +37,6 @@ import org.nd4j.autodiff.validation.OpValidation;
 import org.nd4j.base.Preconditions;
 import org.nd4j.imports.TFGraphs.listener.OpExecOrderListener;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
-import org.nd4j.linalg.BaseNd4jTest;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.iter.NdIndexIterator;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -46,7 +44,6 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
 import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
 import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
 import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.factory.Nd4jBackend;
 import org.nd4j.linalg.function.BiFunction;
 import org.nd4j.linalg.indexing.BooleanIndexing;
 import org.nd4j.linalg.indexing.conditions.Conditions;
@@ -301,7 +298,7 @@ public class TFGraphTestAllHelper {
             Map<String,SameDiffOp> fns = graph.getOps();
             List<String> execOrder = listener.getOpNamesList();
             for(String opName : execOrder){
-                String[] outputs = graph.getOutputsForFunction(fns.get(opName).getOp());
+                String[] outputs = graph.getOutputsForOp(fns.get(opName).getOp());
                 Collections.addAll(varNames, outputs);
             }
@@ -334,8 +331,8 @@ public class TFGraphTestAllHelper {
                 if(countExceeds > 0){
                     maxRE = relError.maxNumber().doubleValue();
                     //Find the op that this variable is produced by
-                    op = graph.getVariableOutputFunction(varName);
-                    opInputs = graph.getInputsForFunction(op);
+                    op = graph.getVariableOutputOp(varName);
+                    opInputs = graph.getInputsForOp(op);
                 }
@@ -732,9 +732,9 @@ public class TensorFlowImportTest extends BaseNd4jTest {
         }

         val functions = new HashMap<String, DifferentialFunction>();
-        for (val func: tg.functions()) {
+        for (val func: tg.ops()) {
             val ownName = func.getOwnName();
-            val outName = func.outputVariables()[0].getVarName();
+            String outName = func.outputVariables()[0].getVarName();

             assertTrue("Missing ownName: [" + ownName +"]",variables.containsKey(ownName));
             assertEquals(ownName, outName);
@@ -72,12 +72,12 @@ import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
 import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
 import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpaceND;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.Axpy;
-import org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse;
 import org.nd4j.linalg.api.ops.impl.transforms.same.Sign;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
@@ -5226,7 +5226,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
         INDArray array = Nd4j.create(new double[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
         INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});

-        INDArray rev = Nd4j.getExecutioner().exec(new OldReverse(array, Nd4j.createUninitialized(array.length())));
+        INDArray rev = Nd4j.getExecutioner().exec(new Reverse(array, Nd4j.createUninitialized(array.length())))[0];

         assertEquals(exp, rev);
     }

@@ -5236,7 +5236,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
         INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
         INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10});

-        INDArray rev = Nd4j.getExecutioner().exec(new OldReverse(array, Nd4j.createUninitialized(array.length())));
+        INDArray rev = Nd4j.getExecutioner().exec(new Reverse(array, Nd4j.createUninitialized(array.length())))[0];

         assertEquals(exp, rev);
     }
@@ -35,7 +35,7 @@ import org.nd4j.linalg.api.ops.impl.reduce.bool.IsInf;
 import org.nd4j.linalg.api.ops.impl.reduce.bool.IsNaN;
 import org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero;
 import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
-import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldEqualTo;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
 import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
 import org.nd4j.linalg.factory.Nd4j;
@@ -276,7 +276,7 @@ public class MixedDataTypesTests extends BaseNd4jTest {
         val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.INT);
         val exp = new long[]{1, 0, 0, 1};

-        val result = Nd4j.getExecutioner().exec(new OldEqualTo(arrayX, arrayY));
+        val result = Nd4j.getExecutioner().exec(new EqualTo(arrayX, arrayY, arrayX.ulike().castTo(DataType.BOOL)))[0];
         assertEquals(DataType.BOOL, result.dataType());
         val arr = result.data().asLong();

@@ -369,13 +369,13 @@ public class MixedDataTypesTests extends BaseNd4jTest {
         val result = Nd4j.getExecutioner().exec(op);
     }

-    @Test(expected = IllegalArgumentException.class)
+    @Test(expected = RuntimeException.class)
     public void testTypesValidation_2() {
         val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
         val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.LONG);
         val exp = new long[]{1, 0, 0, 1};

-        val result = Nd4j.getExecutioner().exec(new OldEqualTo(arrayX, arrayY));
+        val result = Nd4j.getExecutioner().exec(new EqualTo(arrayX, arrayY, arrayX.ulike().castTo(DataType.BOOL)))[0];
         val arr = result.data().asLong();

         assertArrayEquals(exp, arr);
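Note: unlike `OldEqualTo`, the `EqualTo` custom op in these tests is constructed with an explicit, pre-allocated BOOL output (`ulike().castTo(DataType.BOOL)`), and `exec` again returns an `INDArray[]`. A minimal sketch of the new call shape:

```java
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo;
import org.nd4j.linalg.factory.Nd4j;

public class EqualToSketch {
    public static void main(String[] args) {
        INDArray x = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
        INDArray y = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.INT);

        // Pre-allocate a BOOL output with the same shape/ordering as x.
        INDArray z = x.ulike().castTo(DataType.BOOL);
        INDArray result = Nd4j.getExecutioner().exec(new EqualTo(x, y, z))[0];

        System.out.println(result); // expected [true, false, false, true]
    }
}
```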
@@ -45,7 +45,7 @@ import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan;
 import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
-import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
+import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.Log;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.SetRange;

@@ -205,7 +205,7 @@ public class OpExecutionerTests extends BaseNd4jTest {
         INDArray x = Nd4j.ones(5);
         INDArray xDup = x.dup();
         INDArray solution = Nd4j.valueArrayOf(5, 1.0);
-        opExecutioner.exec(new OldMulOp(x, xDup, x));
+        opExecutioner.exec(new MulOp(x, xDup, x));
         assertEquals(solution, x);
     }
@@ -55,7 +55,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.HistogramFixedWidth;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
-import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
+import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.Log;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.SetRange;

@@ -236,7 +236,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest {
         INDArray x = Nd4j.ones(5);
         INDArray xDup = x.dup();
         INDArray solution = Nd4j.valueArrayOf(5, 1.0);
-        opExecutioner.exec(new OldMulOp(x, xDup, x));
+        opExecutioner.exec(new MulOp(x, xDup, x));
         assertEquals(solution, x);
     }
@@ -72,8 +72,9 @@ public class PaddingTests extends BaseNd4jTest {

     @Test
     public void testPad() {
+
         INDArray start = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
-        INDArray ret = Nd4j.pad(start, new int[] {5, 5}, Nd4j.PadMode.CONSTANT);
+        INDArray ret = Nd4j.pad(start, 5, 5);
         double[][] data = new double[][] {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.},
                 {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.},
                 {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.},
@@ -64,8 +64,7 @@ public class PaddingTestsC extends BaseNd4jTest {
         INDArray ret = Nd4j.create(new double[] {1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
                 4, 4, 4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
                 4, 4, 4, 4, 4, 4, 4, 4}, new int[] {1, 1, 8, 8});
-        INDArray padded = Nd4j.pad(ret, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}},
-                Nd4j.PadMode.CONSTANT);
+        INDArray padded = Nd4j.pad(ret, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}});

         INDArray assertion = Nd4j.create(new double[] {1, 1, 1, 1, 1, 1, 1, 1, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, 3, 3, 3, 3,
                 3, 3, 3, 3, 0, 4, 4, 4, 4, 4, 4, 4, 4, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0,

@@ -104,8 +103,7 @@ public class PaddingTestsC extends BaseNd4jTest {
             // FIXME: int cast
             int outWidth = Convolution.outSize((int) h, kh, sy, ph, 1, true);
             int outHeight = Convolution.outSize((int) w, kw, sx, pw, 1, true);
-            INDArray padded = Nd4j.pad(linspaced, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}},
-                    Nd4j.PadMode.CONSTANT);
+            INDArray padded = Nd4j.pad(linspaced, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}});
             System.out.println(padded);
         }
@@ -127,8 +127,7 @@ public class IndexingTestsC extends BaseNd4jTest {
                 4, 4, 4, 4, 4, 4, 4, 4}, new long[] {1, 1, 8, 8});

-        INDArray padded = Nd4j.pad(img, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}},
-                Nd4j.PadMode.CONSTANT);
+        INDArray padded = Nd4j.pad(img, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}});

         INDArray get = padded.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(i, sy, iLim),
                 NDArrayIndex.interval(j, sx, jLim));