Uniform distribution op tweaks + 'specified output dtype' constructor (#38)

* Uniform distribution op tweaks + 'specified output dtype' constructor

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Validation tweak

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-11-08 18:08:25 +11:00 committed by GitHub
parent 929c1dc5c7
commit 2f84ea666d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 70 additions and 20 deletions

View File

@ -401,8 +401,8 @@ public class DifferentialFunctionFactory {
return new MeshGrid(sameDiff(), cartesian, inputs).outputVariables();
}
public SDVariable randomUniform(double min, double max, SDVariable shape) {
return new DistributionUniform(sameDiff(), shape, min, max).outputVariable();
public SDVariable randomUniform(double min, double max, SDVariable shape, DataType dataType) {
return new DistributionUniform(sameDiff(), shape, min, max, dataType).outputVariable();
}
public SDVariable randomUniform(double min, double max, long... shape) {

View File

@ -18,6 +18,7 @@ package org.nd4j.autodiff.samediff.ops;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger;
@ -237,9 +238,23 @@ public class SDRandom extends SDOps {
return uniform(null, min, max, shape);
}
/**
* @see #uniform(String, double, double, SDVariable)
*/
public SDVariable uniform(double min, double max, SDVariable shape, DataType dataType) {
return uniform(null, min, max, shape, dataType);
}
/**
* As per {@link #uniform(double, double, SDVariable, DataType)} but with Float32 output
*/
public SDVariable uniform(String name, double min, double max, SDVariable shape) {
return uniform(name, min, max, shape, null);
}
/**
* Generate a new random SDVariable, where values are randomly sampled according to a uniform distribution,
* U(min,max)<br>
* U(min,max). Note that the output datatype may optionally be specified. If not specified (null) - float32 output is returned<br>
* See {@link #uniform(double, double, long...)} for the equivalent function where the shape is
* specified as a long[] instead
*
@ -247,11 +262,12 @@ public class SDRandom extends SDOps {
* @param min Minimum value
* @param max Maximum value. Must satisfy max >= min
* @param shape Shape of the new random SDVariable, as a 1D array
* @return New SDVariable
* @param dataType Data type of the output array (if null: Float32 output is returned)
* @return New SDVariable, of the specified data type
*/
public SDVariable uniform(String name, double min, double max, SDVariable shape) {
public SDVariable uniform(String name, double min, double max, SDVariable shape, DataType dataType) {
validateInteger("uniform random", shape);
SDVariable ret = f().randomUniform(min, max, shape);
SDVariable ret = f().randomUniform(min, max, shape, dataType);
return updateVariableNameAndReference(ret, name);
}

View File

@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -41,30 +42,47 @@ import java.util.Map;
public class DistributionUniform extends DynamicCustomOp {
private double min = 0.0;
private double max = 1.0;
private DataType dataType;
public DistributionUniform() {
//
}
public DistributionUniform(SameDiff sd, SDVariable shape, double min, double max) {
this(sd, shape, min, max, null);
}
public DistributionUniform(SameDiff sd, SDVariable shape, double min, double max, DataType dataType){
super(null, sd, new SDVariable[]{shape});
Preconditions.checkState(min <= max, "Minimum (%s) must be <= max (%s)", min, max);
addTArgument(min, max);
Preconditions.checkState(dataType == null || dataType.isNumerical(), "Only numerical datatypes can be used with DistributionUniform - rquested output datatype: %s", dataType);
this.dataType = dataType;
this.min = min;
this.max = max;
addArgs();
}
public DistributionUniform(INDArray shape, INDArray out, double min, double max){
super(null, new INDArray[]{shape}, new INDArray[]{out}, Arrays.asList(min, max), (List<Integer>)null);
this.min = min;
this.max = max;
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
addArgs();
AttrValue v = attributesForNode.get("dtype");
dataType = TFGraphMapper.convertType(v.getType());
addIArgument(dataType.toInt());
}
protected void addArgs() {
tArguments.clear();
addTArgument(min, max);
if(dataType != null){
iArguments.clear();
addIArgument(dataType.toInt());
}
}
@Override
@ -85,8 +103,10 @@ public class DistributionUniform extends DynamicCustomOp {
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes);
//Input data type specifies the shape; output data type should be any float
//TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
//Input data type specifies the shape
if(dataType != null){
return Collections.singletonList(dataType);
}
return Collections.singletonList(DataType.FLOAT);
}
}

View File

@ -382,4 +382,21 @@ public class RandomOpValidation extends BaseOpValidation {
INDArray out = Nd4j.exec(all);
assertEquals(x, out);
}
@Test
public void testUniformDtype(){
for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){
SameDiff sd = SameDiff.create();
SDVariable shape = sd.constant("shape", Nd4j.createFromArray(1, 100));
SDVariable out = sd.random.uniform(0, 10, shape, t);
INDArray arr = out.eval();
assertEquals(t, arr.dataType());
double min = arr.minNumber().doubleValue();
double max = arr.maxNumber().doubleValue();
double mean = arr.meanNumber().doubleValue();
assertEquals(0, min, 0.5);
assertEquals(10, max, 0.5);
assertEquals(5.5, mean, 1);
}
}
}

View File

@ -2311,7 +2311,6 @@ public class SameDiffTests extends BaseNd4jTest {
SDVariable loss = out.std("out", true);
INDArray outArr = loss.eval();
// sd.execBackwards(Collections.emptyMap());
Map<String,INDArray> grads = sd.calculateGradients(null, in.name(), w.name(), out.name());
Map<String, INDArray> origGrad = new HashMap<>();
@ -2321,7 +2320,6 @@ public class SameDiffTests extends BaseNd4jTest {
in.getArr().assign(Nd4j.rand(in.getArr().shape()));
INDArray outArr2 = loss.eval();
// sd.execBackwards(Collections.emptyMap());
grads = sd.calculateGradients(null, in.name(), w.name(), out.name());
assertNotEquals(outArr, outArr2);
@ -2641,8 +2639,7 @@ public class SameDiffTests extends BaseNd4jTest {
.expectedOutput("out", out)
.gradientCheck(true));
assertNull(err, err);
assertNull(err);
}
@Test