diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 21559d7f4..52e725191 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -401,8 +401,8 @@ public class DifferentialFunctionFactory { return new MeshGrid(sameDiff(), cartesian, inputs).outputVariables(); } - public SDVariable randomUniform(double min, double max, SDVariable shape) { - return new DistributionUniform(sameDiff(), shape, min, max).outputVariable(); + public SDVariable randomUniform(double min, double max, SDVariable shape, DataType dataType) { + return new DistributionUniform(sameDiff(), shape, min, max, dataType).outputVariable(); } public SDVariable randomUniform(double min, double max, long... shape) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java index 463b412f3..66e52c151 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java @@ -18,6 +18,7 @@ package org.nd4j.autodiff.samediff.ops; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger; @@ -237,21 +238,36 @@ public class SDRandom extends SDOps { return uniform(null, min, max, shape); } + /** + * @see #uniform(String, double, double, SDVariable) + */ + public SDVariable uniform(double min, double max, SDVariable shape, DataType dataType) { + return uniform(null, min, max, shape, dataType); + } + + /** + * As per {@link #uniform(double, double, SDVariable, DataType)} but with Float32 output + */ + public SDVariable uniform(String name, double min, double max, SDVariable shape) { + return uniform(name, min, max, shape, null); + } + /** * Generate a new random SDVariable, where values are randomly sampled according to a uniform distribution, - * U(min,max)
+ * U(min,max). Note that the output datatype may optionally be specified. If not specified (null) - float32 output is returned
* See {@link #uniform(double, double, long...)} for the equivalent function where the shape is * specified as a long[] instead * - * @param name Name of the new SDVariable - * @param min Minimum value - * @param max Maximum value. Must satisfy max >= min - * @param shape Shape of the new random SDVariable, as a 1D array - * @return New SDVariable + * @param name Name of the new SDVariable + * @param min Minimum value + * @param max Maximum value. Must satisfy max >= min + * @param shape Shape of the new random SDVariable, as a 1D array + * @param dataType Data type of the output array (if null: Float32 output is returned) + * @return New SDVariable, of the specified data type */ - public SDVariable uniform(String name, double min, double max, SDVariable shape) { + public SDVariable uniform(String name, double min, double max, SDVariable shape, DataType dataType) { validateInteger("uniform random", shape); - SDVariable ret = f().randomUniform(min, max, shape); + SDVariable ret = f().randomUniform(min, max, shape, dataType); return updateVariableNameAndReference(ret, name); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java index 97330486f..0744533ba 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java @@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -41,30 +42,47 @@ import java.util.Map; public class DistributionUniform extends DynamicCustomOp { private double min = 0.0; private double max = 1.0; + private DataType dataType; public DistributionUniform() { // } - public DistributionUniform(SameDiff sd, SDVariable shape, double min, double max){ + public DistributionUniform(SameDiff sd, SDVariable shape, double min, double max) { + this(sd, shape, min, max, null); + } + + public DistributionUniform(SameDiff sd, SDVariable shape, double min, double max, DataType dataType){ super(null, sd, new SDVariable[]{shape}); Preconditions.checkState(min <= max, "Minimum (%s) must be <= max (%s)", min, max); - addTArgument(min, max); + Preconditions.checkState(dataType == null || dataType.isNumerical(), "Only numerical datatypes can be used with DistributionUniform - rquested output datatype: %s", dataType); + this.dataType = dataType; + this.min = min; + this.max = max; + addArgs(); } public DistributionUniform(INDArray shape, INDArray out, double min, double max){ super(null, new INDArray[]{shape}, new INDArray[]{out}, Arrays.asList(min, max), (List)null); + this.min = min; + this.max = max; } @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph); - addArgs(); + AttrValue v = attributesForNode.get("dtype"); + dataType = TFGraphMapper.convertType(v.getType()); + addIArgument(dataType.toInt()); } protected void addArgs() { + tArguments.clear(); addTArgument(min, max); + if(dataType != null){ + iArguments.clear(); + addIArgument(dataType.toInt()); + } } @Override @@ -85,8 +103,10 @@ public class DistributionUniform extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes); - //Input data type specifies the shape; output data type should be any float - //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 + //Input data type specifies the shape + if(dataType != null){ + return Collections.singletonList(dataType); + } return Collections.singletonList(DataType.FLOAT); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java index 8d64f6404..6c9633a41 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java @@ -382,4 +382,21 @@ public class RandomOpValidation extends BaseOpValidation { INDArray out = Nd4j.exec(all); assertEquals(x, out); } + + @Test + public void testUniformDtype(){ + for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){ + SameDiff sd = SameDiff.create(); + SDVariable shape = sd.constant("shape", Nd4j.createFromArray(1, 100)); + SDVariable out = sd.random.uniform(0, 10, shape, t); + INDArray arr = out.eval(); + assertEquals(t, arr.dataType()); + double min = arr.minNumber().doubleValue(); + double max = arr.maxNumber().doubleValue(); + double mean = arr.meanNumber().doubleValue(); + assertEquals(0, min, 0.5); + assertEquals(10, max, 0.5); + assertEquals(5.5, mean, 1); + } + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index 69c69388b..a045bdc99 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -2311,7 +2311,6 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable loss = out.std("out", true); INDArray outArr = loss.eval(); -// sd.execBackwards(Collections.emptyMap()); Map grads = sd.calculateGradients(null, in.name(), w.name(), out.name()); Map origGrad = new HashMap<>(); @@ -2321,7 +2320,6 @@ public class SameDiffTests extends BaseNd4jTest { in.getArr().assign(Nd4j.rand(in.getArr().shape())); INDArray outArr2 = loss.eval(); -// sd.execBackwards(Collections.emptyMap()); grads = sd.calculateGradients(null, in.name(), w.name(), out.name()); assertNotEquals(outArr, outArr2); @@ -2641,8 +2639,7 @@ public class SameDiffTests extends BaseNd4jTest { .expectedOutput("out", out) .gradientCheck(true)); - assertNull(err, err); - + assertNull(err); } @Test