Uniform distribution op tweaks + 'specified output dtype' constructor (#38)
* Uniform distribution op tweaks + 'specified output dtype' constructor Signed-off-by: AlexDBlack <blacka101@gmail.com> * Validation tweak Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
929c1dc5c7
commit
2f84ea666d
|
@ -401,8 +401,8 @@ public class DifferentialFunctionFactory {
|
||||||
return new MeshGrid(sameDiff(), cartesian, inputs).outputVariables();
|
return new MeshGrid(sameDiff(), cartesian, inputs).outputVariables();
|
||||||
}
|
}
|
||||||
|
|
||||||
public SDVariable randomUniform(double min, double max, SDVariable shape) {
|
public SDVariable randomUniform(double min, double max, SDVariable shape, DataType dataType) {
|
||||||
return new DistributionUniform(sameDiff(), shape, min, max).outputVariable();
|
return new DistributionUniform(sameDiff(), shape, min, max, dataType).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
public SDVariable randomUniform(double min, double max, long... shape) {
|
public SDVariable randomUniform(double min, double max, long... shape) {
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.autodiff.samediff.ops;
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
|
||||||
import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger;
|
import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger;
|
||||||
|
|
||||||
|
@ -237,9 +238,23 @@ public class SDRandom extends SDOps {
|
||||||
return uniform(null, min, max, shape);
|
return uniform(null, min, max, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @see #uniform(String, double, double, SDVariable)
|
||||||
|
*/
|
||||||
|
public SDVariable uniform(double min, double max, SDVariable shape, DataType dataType) {
|
||||||
|
return uniform(null, min, max, shape, dataType);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* As per {@link #uniform(double, double, SDVariable, DataType)} but with Float32 output
|
||||||
|
*/
|
||||||
|
public SDVariable uniform(String name, double min, double max, SDVariable shape) {
|
||||||
|
return uniform(name, min, max, shape, null);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate a new random SDVariable, where values are randomly sampled according to a uniform distribution,
|
* Generate a new random SDVariable, where values are randomly sampled according to a uniform distribution,
|
||||||
* U(min,max)<br>
|
* U(min,max). Note that the output datatype may optionally be specified. If not specified (null) - float32 output is returned<br>
|
||||||
* See {@link #uniform(double, double, long...)} for the equivalent function where the shape is
|
* See {@link #uniform(double, double, long...)} for the equivalent function where the shape is
|
||||||
* specified as a long[] instead
|
* specified as a long[] instead
|
||||||
*
|
*
|
||||||
|
@ -247,11 +262,12 @@ public class SDRandom extends SDOps {
|
||||||
* @param min Minimum value
|
* @param min Minimum value
|
||||||
* @param max Maximum value. Must satisfy max >= min
|
* @param max Maximum value. Must satisfy max >= min
|
||||||
* @param shape Shape of the new random SDVariable, as a 1D array
|
* @param shape Shape of the new random SDVariable, as a 1D array
|
||||||
* @return New SDVariable
|
* @param dataType Data type of the output array (if null: Float32 output is returned)
|
||||||
|
* @return New SDVariable, of the specified data type
|
||||||
*/
|
*/
|
||||||
public SDVariable uniform(String name, double min, double max, SDVariable shape) {
|
public SDVariable uniform(String name, double min, double max, SDVariable shape, DataType dataType) {
|
||||||
validateInteger("uniform random", shape);
|
validateInteger("uniform random", shape);
|
||||||
SDVariable ret = f().randomUniform(min, max, shape);
|
SDVariable ret = f().randomUniform(min, max, shape, dataType);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
@ -41,30 +42,47 @@ import java.util.Map;
|
||||||
public class DistributionUniform extends DynamicCustomOp {
|
public class DistributionUniform extends DynamicCustomOp {
|
||||||
private double min = 0.0;
|
private double min = 0.0;
|
||||||
private double max = 1.0;
|
private double max = 1.0;
|
||||||
|
private DataType dataType;
|
||||||
|
|
||||||
public DistributionUniform() {
|
public DistributionUniform() {
|
||||||
//
|
//
|
||||||
}
|
}
|
||||||
|
|
||||||
public DistributionUniform(SameDiff sd, SDVariable shape, double min, double max) {
|
public DistributionUniform(SameDiff sd, SDVariable shape, double min, double max) {
|
||||||
|
this(sd, shape, min, max, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public DistributionUniform(SameDiff sd, SDVariable shape, double min, double max, DataType dataType){
|
||||||
super(null, sd, new SDVariable[]{shape});
|
super(null, sd, new SDVariable[]{shape});
|
||||||
Preconditions.checkState(min <= max, "Minimum (%s) must be <= max (%s)", min, max);
|
Preconditions.checkState(min <= max, "Minimum (%s) must be <= max (%s)", min, max);
|
||||||
addTArgument(min, max);
|
Preconditions.checkState(dataType == null || dataType.isNumerical(), "Only numerical datatypes can be used with DistributionUniform - rquested output datatype: %s", dataType);
|
||||||
|
this.dataType = dataType;
|
||||||
|
this.min = min;
|
||||||
|
this.max = max;
|
||||||
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
public DistributionUniform(INDArray shape, INDArray out, double min, double max){
|
public DistributionUniform(INDArray shape, INDArray out, double min, double max){
|
||||||
super(null, new INDArray[]{shape}, new INDArray[]{out}, Arrays.asList(min, max), (List<Integer>)null);
|
super(null, new INDArray[]{shape}, new INDArray[]{out}, Arrays.asList(min, max), (List<Integer>)null);
|
||||||
|
this.min = min;
|
||||||
|
this.max = max;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||||
super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
|
AttrValue v = attributesForNode.get("dtype");
|
||||||
addArgs();
|
dataType = TFGraphMapper.convertType(v.getType());
|
||||||
|
addIArgument(dataType.toInt());
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void addArgs() {
|
protected void addArgs() {
|
||||||
|
tArguments.clear();
|
||||||
addTArgument(min, max);
|
addTArgument(min, max);
|
||||||
|
if(dataType != null){
|
||||||
|
iArguments.clear();
|
||||||
|
addIArgument(dataType.toInt());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -85,8 +103,10 @@ public class DistributionUniform extends DynamicCustomOp {
|
||||||
@Override
|
@Override
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes);
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes);
|
||||||
//Input data type specifies the shape; output data type should be any float
|
//Input data type specifies the shape
|
||||||
//TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
|
if(dataType != null){
|
||||||
|
return Collections.singletonList(dataType);
|
||||||
|
}
|
||||||
return Collections.singletonList(DataType.FLOAT);
|
return Collections.singletonList(DataType.FLOAT);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -382,4 +382,21 @@ public class RandomOpValidation extends BaseOpValidation {
|
||||||
INDArray out = Nd4j.exec(all);
|
INDArray out = Nd4j.exec(all);
|
||||||
assertEquals(x, out);
|
assertEquals(x, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testUniformDtype(){
|
||||||
|
for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){
|
||||||
|
SameDiff sd = SameDiff.create();
|
||||||
|
SDVariable shape = sd.constant("shape", Nd4j.createFromArray(1, 100));
|
||||||
|
SDVariable out = sd.random.uniform(0, 10, shape, t);
|
||||||
|
INDArray arr = out.eval();
|
||||||
|
assertEquals(t, arr.dataType());
|
||||||
|
double min = arr.minNumber().doubleValue();
|
||||||
|
double max = arr.maxNumber().doubleValue();
|
||||||
|
double mean = arr.meanNumber().doubleValue();
|
||||||
|
assertEquals(0, min, 0.5);
|
||||||
|
assertEquals(10, max, 0.5);
|
||||||
|
assertEquals(5.5, mean, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2311,7 +2311,6 @@ public class SameDiffTests extends BaseNd4jTest {
|
||||||
SDVariable loss = out.std("out", true);
|
SDVariable loss = out.std("out", true);
|
||||||
|
|
||||||
INDArray outArr = loss.eval();
|
INDArray outArr = loss.eval();
|
||||||
// sd.execBackwards(Collections.emptyMap());
|
|
||||||
Map<String,INDArray> grads = sd.calculateGradients(null, in.name(), w.name(), out.name());
|
Map<String,INDArray> grads = sd.calculateGradients(null, in.name(), w.name(), out.name());
|
||||||
|
|
||||||
Map<String, INDArray> origGrad = new HashMap<>();
|
Map<String, INDArray> origGrad = new HashMap<>();
|
||||||
|
@ -2321,7 +2320,6 @@ public class SameDiffTests extends BaseNd4jTest {
|
||||||
|
|
||||||
in.getArr().assign(Nd4j.rand(in.getArr().shape()));
|
in.getArr().assign(Nd4j.rand(in.getArr().shape()));
|
||||||
INDArray outArr2 = loss.eval();
|
INDArray outArr2 = loss.eval();
|
||||||
// sd.execBackwards(Collections.emptyMap());
|
|
||||||
grads = sd.calculateGradients(null, in.name(), w.name(), out.name());
|
grads = sd.calculateGradients(null, in.name(), w.name(), out.name());
|
||||||
|
|
||||||
assertNotEquals(outArr, outArr2);
|
assertNotEquals(outArr, outArr2);
|
||||||
|
@ -2641,8 +2639,7 @@ public class SameDiffTests extends BaseNd4jTest {
|
||||||
.expectedOutput("out", out)
|
.expectedOutput("out", out)
|
||||||
.gradientCheck(true));
|
.gradientCheck(true));
|
||||||
|
|
||||||
assertNull(err, err);
|
assertNull(err);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
Loading…
Reference in New Issue