Uniform distribution op tweaks + 'specified output dtype' constructor (#38)

* Uniform distribution op tweaks + 'specified output dtype' constructor
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Validation tweak
  Signed-off-by: AlexDBlack <blacka101@gmail.com>

Branch: master
parent 929c1dc5c7
commit 2f84ea666d
@@ -401,8 +401,8 @@ public class DifferentialFunctionFactory {
         return new MeshGrid(sameDiff(), cartesian, inputs).outputVariables();
     }
 
-    public SDVariable randomUniform(double min, double max, SDVariable shape) {
-        return new DistributionUniform(sameDiff(), shape, min, max).outputVariable();
+    public SDVariable randomUniform(double min, double max, SDVariable shape, DataType dataType) {
+        return new DistributionUniform(sameDiff(), shape, min, max, dataType).outputVariable();
     }
 
     public SDVariable randomUniform(double min, double max, long... shape) {
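For context, a minimal sketch (not part of this commit) of how a caller inside an SDOps subclass such as SDRandom would now select the output type via the factory accessor f(); shapeVar is an illustrative 1D integer SDVariable holding the target shape:

    // Explicitly request a double-precision output
    SDVariable asDouble = f().randomUniform(0.0, 1.0, shapeVar, DataType.DOUBLE);
    // Passing null falls back to the FLOAT default (see calculateOutputDataTypes further down)
    SDVariable asFloat = f().randomUniform(0.0, 1.0, shapeVar, null);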
@@ -18,6 +18,7 @@ package org.nd4j.autodiff.samediff.ops;
 
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.buffer.DataType;
 
 import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger;
 
@@ -237,21 +238,36 @@ public class SDRandom extends SDOps {
         return uniform(null, min, max, shape);
     }
 
+    /**
+     * @see #uniform(String, double, double, SDVariable)
+     */
+    public SDVariable uniform(double min, double max, SDVariable shape, DataType dataType) {
+        return uniform(null, min, max, shape, dataType);
+    }
+
+    /**
+     * As per {@link #uniform(double, double, SDVariable, DataType)} but with Float32 output
+     */
+    public SDVariable uniform(String name, double min, double max, SDVariable shape) {
+        return uniform(name, min, max, shape, null);
+    }
+
     /**
      * Generate a new random SDVariable, where values are randomly sampled according to a uniform distribution,
-     * U(min,max)<br>
+     * U(min,max). Note that the output datatype may optionally be specified. If not specified (null) - float32 output is returned<br>
      * See {@link #uniform(double, double, long...)} for the equivalent function where the shape is
      * specified as a long[] instead
      *
-     * @param name  Name of the new SDVariable
-     * @param min   Minimum value
-     * @param max   Maximum value. Must satisfy max >= min
-     * @param shape Shape of the new random SDVariable, as a 1D array
-     * @return New SDVariable
+     * @param name     Name of the new SDVariable
+     * @param min      Minimum value
+     * @param max      Maximum value. Must satisfy max >= min
+     * @param shape    Shape of the new random SDVariable, as a 1D array
+     * @param dataType Data type of the output array (if null: Float32 output is returned)
+     * @return New SDVariable, of the specified data type
      */
-    public SDVariable uniform(String name, double min, double max, SDVariable shape) {
+    public SDVariable uniform(String name, double min, double max, SDVariable shape, DataType dataType) {
         validateInteger("uniform random", shape);
-        SDVariable ret = f().randomUniform(min, max, shape);
+        SDVariable ret = f().randomUniform(min, max, shape, dataType);
         return updateVariableNameAndReference(ret, name);
     }
 
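A short usage sketch of the new public API, assuming the usual SameDiff/Nd4j imports; the variable and output names are illustrative, and the shape constant mirrors the test further down:

    SameDiff sd = SameDiff.create();
    SDVariable shape = sd.constant("shape", Nd4j.createFromArray(1, 100));              // 1D shape descriptor
    SDVariable asDouble = sd.random.uniform("uDouble", 0, 10, shape, DataType.DOUBLE);  // explicit output dtype
    SDVariable asFloat = sd.random.uniform("uFloat", 0, 10, shape);                     // null dtype -> FLOAT output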
@@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
+import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
@@ -41,30 +42,47 @@ import java.util.Map;
 public class DistributionUniform extends DynamicCustomOp {
     private double min = 0.0;
     private double max = 1.0;
+    private DataType dataType;
 
     public DistributionUniform() {
         //
     }
 
-    public DistributionUniform(SameDiff sd, SDVariable shape, double min, double max){
+    public DistributionUniform(SameDiff sd, SDVariable shape, double min, double max) {
+        this(sd, shape, min, max, null);
+    }
+
+    public DistributionUniform(SameDiff sd, SDVariable shape, double min, double max, DataType dataType){
         super(null, sd, new SDVariable[]{shape});
         Preconditions.checkState(min <= max, "Minimum (%s) must be <= max (%s)", min, max);
-        addTArgument(min, max);
+        Preconditions.checkState(dataType == null || dataType.isNumerical(), "Only numerical datatypes can be used with DistributionUniform - requested output datatype: %s", dataType);
+        this.dataType = dataType;
         this.min = min;
         this.max = max;
+        addArgs();
     }
 
     public DistributionUniform(INDArray shape, INDArray out, double min, double max){
         super(null, new INDArray[]{shape}, new INDArray[]{out}, Arrays.asList(min, max), (List<Integer>)null);
         this.min = min;
         this.max = max;
     }
 
 
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
         super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
-        addArgs();
+        AttrValue v = attributesForNode.get("dtype");
+        dataType = TFGraphMapper.convertType(v.getType());
+        addIArgument(dataType.toInt());
     }
 
     protected void addArgs() {
+        tArguments.clear();
         addTArgument(min, max);
+        if(dataType != null){
+            iArguments.clear();
+            addIArgument(dataType.toInt());
+        }
     }
 
     @Override
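For the non-SameDiff path, the INDArray-based constructor above can be executed directly; a hedged sketch follows (the DOUBLE output buffer is an illustrative choice, and with this constructor the output dtype simply comes from the array that is passed in):

    INDArray shapeArr = Nd4j.createFromArray(1, 100);       // 1D shape descriptor: 1 x 100
    INDArray out = Nd4j.create(DataType.DOUBLE, 1, 100);    // pre-allocated output buffer
    Nd4j.exec(new DistributionUniform(shapeArr, out, 0.0, 10.0));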
@@ -85,8 +103,10 @@ public class DistributionUniform extends DynamicCustomOp {
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
         Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes);
-        //Input data type specifies the shape; output data type should be any float
-        //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
+        //Input data type specifies the shape
+        if(dataType != null){
+            return Collections.singletonList(dataType);
+        }
         return Collections.singletonList(DataType.FLOAT);
     }
 }
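In other words, the output-type resolution is simply "explicit dtype if one was stored on the op, otherwise FLOAT". As a standalone restatement (the helper name is illustrative, not part of the commit):

    // Illustrative restatement of the dtype fallback implemented above
    static List<DataType> resolveOutputType(DataType requested) {
        return Collections.singletonList(requested != null ? requested : DataType.FLOAT);
    }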
@@ -382,4 +382,21 @@ public class RandomOpValidation extends BaseOpValidation {
         INDArray out = Nd4j.exec(all);
         assertEquals(x, out);
     }
+
+    @Test
+    public void testUniformDtype(){
+        for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE}){
+            SameDiff sd = SameDiff.create();
+            SDVariable shape = sd.constant("shape", Nd4j.createFromArray(1, 100));
+            SDVariable out = sd.random.uniform(0, 10, shape, t);
+            INDArray arr = out.eval();
+            assertEquals(t, arr.dataType());
+            double min = arr.minNumber().doubleValue();
+            double max = arr.maxNumber().doubleValue();
+            double mean = arr.meanNumber().doubleValue();
+            assertEquals(0, min, 0.5);
+            assertEquals(10, max, 0.5);
+            assertEquals(5.5, mean, 1);
+        }
+    }
 }
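As a sanity check on the assertion tolerances: for 100 samples from U(0, 10) the true mean is (0 + 10) / 2 = 5, with a standard error of roughly (10 / sqrt(12)) / sqrt(100) ≈ 0.29, so the assertEquals(5.5, mean, 1) window of [4.5, 6.5] contains the true mean. Likewise the expected sample minimum and maximum are about 10/101 ≈ 0.1 and 10·100/101 ≈ 9.9, so the 0.5 tolerances on min and max are reasonable.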
@@ -2311,7 +2311,6 @@ public class SameDiffTests extends BaseNd4jTest {
         SDVariable loss = out.std("out", true);
 
         INDArray outArr = loss.eval();
-//        sd.execBackwards(Collections.emptyMap());
         Map<String,INDArray> grads = sd.calculateGradients(null, in.name(), w.name(), out.name());
 
         Map<String, INDArray> origGrad = new HashMap<>();
@@ -2321,7 +2320,6 @@ public class SameDiffTests extends BaseNd4jTest {
 
         in.getArr().assign(Nd4j.rand(in.getArr().shape()));
         INDArray outArr2 = loss.eval();
-//        sd.execBackwards(Collections.emptyMap());
         grads = sd.calculateGradients(null, in.name(), w.name(), out.name());
 
         assertNotEquals(outArr, outArr2);
@@ -2641,8 +2639,7 @@ public class SameDiffTests extends BaseNd4jTest {
                 .expectedOutput("out", out)
                 .gradientCheck(true));
 
-        assertNull(err, err);
-
+        assertNull(err);
     }
 
     @Test