diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
index 21559d7f4..52e725191 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
@@ -401,8 +401,8 @@ public class DifferentialFunctionFactory {
return new MeshGrid(sameDiff(), cartesian, inputs).outputVariables();
}
- public SDVariable randomUniform(double min, double max, SDVariable shape) {
- return new DistributionUniform(sameDiff(), shape, min, max).outputVariable();
+ public SDVariable randomUniform(double min, double max, SDVariable shape, DataType dataType) {
+ return new DistributionUniform(sameDiff(), shape, min, max, dataType).outputVariable();
}
public SDVariable randomUniform(double min, double max, long... shape) {
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java
index 463b412f3..66e52c151 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java
@@ -18,6 +18,7 @@ package org.nd4j.autodiff.samediff.ops;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.buffer.DataType;
import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger;
@@ -237,21 +238,36 @@ public class SDRandom extends SDOps {
return uniform(null, min, max, shape);
}
+ /**
+ * @see #uniform(String, double, double, SDVariable)
+ */
+ public SDVariable uniform(double min, double max, SDVariable shape, DataType dataType) {
+ return uniform(null, min, max, shape, dataType);
+ }
+
+ /**
+ * As per {@link #uniform(double, double, SDVariable, DataType)} but with Float32 output
+ */
+ public SDVariable uniform(String name, double min, double max, SDVariable shape) {
+ return uniform(name, min, max, shape, null);
+ }
+
/**
* Generate a new random SDVariable, where values are randomly sampled according to a uniform distribution,
- * U(min,max)
+ * U(min,max). Note that the output datatype may optionally be specified. If not specified (null) - float32 output is returned
* See {@link #uniform(double, double, long...)} for the equivalent function where the shape is
* specified as a long[] instead
*
- * @param name Name of the new SDVariable
- * @param min Minimum value
- * @param max Maximum value. Must satisfy max >= min
- * @param shape Shape of the new random SDVariable, as a 1D array
- * @return New SDVariable
+ * @param name Name of the new SDVariable
+ * @param min Minimum value
+ * @param max Maximum value. Must satisfy max >= min
+ * @param shape Shape of the new random SDVariable, as a 1D array
+ * @param dataType Data type of the output array (if null: Float32 output is returned)
+ * @return New SDVariable, of the specified data type
*/
- public SDVariable uniform(String name, double min, double max, SDVariable shape) {
+ public SDVariable uniform(String name, double min, double max, SDVariable shape, DataType dataType) {
validateInteger("uniform random", shape);
- SDVariable ret = f().randomUniform(min, max, shape);
+ SDVariable ret = f().randomUniform(min, max, shape, dataType);
return updateVariableNameAndReference(ret, name);
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java
index 97330486f..0744533ba 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java
@@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
+import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
@@ -41,30 +42,47 @@ import java.util.Map;
public class DistributionUniform extends DynamicCustomOp {
private double min = 0.0;
private double max = 1.0;
+ private DataType dataType;
public DistributionUniform() {
//
}
- public DistributionUniform(SameDiff sd, SDVariable shape, double min, double max){
+ public DistributionUniform(SameDiff sd, SDVariable shape, double min, double max) {
+ this(sd, shape, min, max, null);
+ }
+
+ public DistributionUniform(SameDiff sd, SDVariable shape, double min, double max, DataType dataType){
super(null, sd, new SDVariable[]{shape});
Preconditions.checkState(min <= max, "Minimum (%s) must be <= max (%s)", min, max);
- addTArgument(min, max);
+ Preconditions.checkState(dataType == null || dataType.isNumerical(), "Only numerical datatypes can be used with DistributionUniform - rquested output datatype: %s", dataType);
+ this.dataType = dataType;
+ this.min = min;
+ this.max = max;
+ addArgs();
}
public DistributionUniform(INDArray shape, INDArray out, double min, double max){
super(null, new INDArray[]{shape}, new INDArray[]{out}, Arrays.asList(min, max), (List)null);
+ this.min = min;
+ this.max = max;
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) {
- super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
- addArgs();
+ AttrValue v = attributesForNode.get("dtype");
+ dataType = TFGraphMapper.convertType(v.getType());
+ addIArgument(dataType.toInt());
}
protected void addArgs() {
+ tArguments.clear();
addTArgument(min, max);
+ if(dataType != null){
+ iArguments.clear();
+ addIArgument(dataType.toInt());
+ }
}
@Override
@@ -85,8 +103,10 @@ public class DistributionUniform extends DynamicCustomOp {
@Override
public List calculateOutputDataTypes(List inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes);
- //Input data type specifies the shape; output data type should be any float
- //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854
+ //Input data type specifies the shape
+ if(dataType != null){
+ return Collections.singletonList(dataType);
+ }
return Collections.singletonList(DataType.FLOAT);
}
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java
index 8d64f6404..6c9633a41 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java
@@ -382,4 +382,21 @@ public class RandomOpValidation extends BaseOpValidation {
INDArray out = Nd4j.exec(all);
assertEquals(x, out);
}
+
+ @Test
+ public void testUniformDtype(){
+ for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){
+ SameDiff sd = SameDiff.create();
+ SDVariable shape = sd.constant("shape", Nd4j.createFromArray(1, 100));
+ SDVariable out = sd.random.uniform(0, 10, shape, t);
+ INDArray arr = out.eval();
+ assertEquals(t, arr.dataType());
+ double min = arr.minNumber().doubleValue();
+ double max = arr.maxNumber().doubleValue();
+ double mean = arr.meanNumber().doubleValue();
+ assertEquals(0, min, 0.5);
+ assertEquals(10, max, 0.5);
+ assertEquals(5.5, mean, 1);
+ }
+ }
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
index 69c69388b..a045bdc99 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
@@ -2311,7 +2311,6 @@ public class SameDiffTests extends BaseNd4jTest {
SDVariable loss = out.std("out", true);
INDArray outArr = loss.eval();
-// sd.execBackwards(Collections.emptyMap());
Map grads = sd.calculateGradients(null, in.name(), w.name(), out.name());
Map origGrad = new HashMap<>();
@@ -2321,7 +2320,6 @@ public class SameDiffTests extends BaseNd4jTest {
in.getArr().assign(Nd4j.rand(in.getArr().shape()));
INDArray outArr2 = loss.eval();
-// sd.execBackwards(Collections.emptyMap());
grads = sd.calculateGradients(null, in.name(), w.name(), out.name());
assertNotEquals(outArr, outArr2);
@@ -2641,8 +2639,7 @@ public class SameDiffTests extends BaseNd4jTest {
.expectedOutput("out", out)
.gradientCheck(true));
- assertNull(err, err);
-
+ assertNull(err);
}
@Test